Unverified Commit 2decec98 authored by SteadfastAsArt's avatar SteadfastAsArt Committed by GitHub
Browse files

[Transformers backend] Ignore MTP weights when num_nextn_predict_layers=0 (#34888)


Signed-off-by: default avatarSteadfastAsArt <695488173@qq.com>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 29b35477
...@@ -300,14 +300,26 @@ class Base( ...@@ -300,14 +300,26 @@ class Base(
for child_name, child_module in module.named_children(): for child_name, child_module in module.named_children():
new_module = child_module new_module = child_module
qual_name = maybe_prefix(prefix, child_name) qual_name = maybe_prefix(prefix, child_name)
# Populate Eagle3 attrs
if ( if (
isinstance(module, nn.ModuleList) isinstance(module, nn.ModuleList)
and len(module) == self.text_config.num_hidden_layers and len(module) == self.text_config.num_hidden_layers
): ):
# Populate Eagle3 attrs
self._target_class = type(child_module) self._target_class = type(child_module)
layer_name = qual_name.removeprefix("model.") layer_name = qual_name.removeprefix("model.")
self._layer_names[int(child_name)] = layer_name self._layer_names[int(child_name)] = layer_name
# MTP weights should not be loaded into the base model
num_hidden_layers = self.text_config.num_hidden_layers
names = (
"n_predict", # Override from SpeculativeConfig
"num_nextn_predict_layers", # Most models
"mtp_num_hidden_layers", # Qwen 3.5
)
n_predict = getattr_iter(self.text_config, names, 0)
for i in range(num_hidden_layers, num_hidden_layers + n_predict):
mtp_prefix = f"{prefix}.{i}."
if mtp_prefix not in self.ignore_unexpected_prefixes:
self.ignore_unexpected_prefixes.append(mtp_prefix)
# Replace modules as needed # Replace modules as needed
if isinstance(child_module, nn.Linear): if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name)) generator = (p for p in tp_plan if re.match(p, qual_name))
......
...@@ -311,8 +311,9 @@ class AutoWeightsLoader: ...@@ -311,8 +311,9 @@ class AutoWeightsLoader:
continue continue
named_parameters = module.named_parameters(recurse=True)
desc_param_keys = { desc_param_keys = {
base_prefix + k for k, _ in module.named_parameters(recurse=True) maybe_prefix(base_prefix, k) for k, _ in named_parameters
} }
msg = ( msg = (
f"There is no module or parameter named {prefix!r} " f"There is no module or parameter named {prefix!r} "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment