Unverified commit f23fb5a7, authored by RickyChen / 陳昭儒 and committed by GitHub
Browse files

[Bugfix] Support HF sharded weights for Mistral3/Pixtral models (#32673)


Signed-off-by: ricky-chaoju <ricky.chen@infinirc.com>
Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
parent 360aa93f
...@@ -502,10 +502,12 @@ class PixtralForConditionalGeneration( ...@@ -502,10 +502,12 @@ class PixtralForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder") return weight[0].startswith(("vision_encoder", "vision_tower"))
def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]): def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter") return weight[0].startswith(
("vision_language_adapter", "multi_modal_projector")
)
def is_patch_merger(weight: tuple[str, torch.Tensor]): def is_patch_merger(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("patch_merger") return weight[0].startswith("patch_merger")
...@@ -543,9 +545,10 @@ class PixtralForConditionalGeneration( ...@@ -543,9 +545,10 @@ class PixtralForConditionalGeneration(
continue continue
# Load vision encoder weights directly # Load vision encoder weights directly
trimmed_name = ".".join(name.split(".")[1:]) trimmed_name = ".".join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name] param = vision_encoder_dict.get(trimmed_name)
with torch.no_grad(): if param is not None:
default_weight_loader(param, w) with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)): elif is_patch_merger((name, w)):
if self.patch_merger is None: if self.patch_merger is None:
continue continue
...@@ -567,12 +570,15 @@ class PixtralForConditionalGeneration( ...@@ -567,12 +570,15 @@ class PixtralForConditionalGeneration(
continue continue
# Load vision-language adapter weights directly # Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:]) trimmed_name = ".".join(name.split(".")[1:])
param = vision_lang_adapter_dict[trimmed_name] param = vision_lang_adapter_dict.get(trimmed_name)
with torch.no_grad(): if param is not None:
default_weight_loader(param, w) with torch.no_grad():
default_weight_loader(param, w)
else: else:
# LLM weights: yield them to be loaded # LLM weights: yield them to be loaded
# by language_model.load_weights # by language_model.load_weights
# Strip "language_model." prefix if present (HF sharded format)
name = name.removeprefix("language_model.")
yield (name, w) yield (name, w)
# Now we call the language model load with the generator # Now we call the language model load with the generator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment