Unverified commit 2562e027, authored by Matthew Bonanni, committed by GitHub
Browse files

[MTP] Validate that MTP weights are actually loaded (#35548)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent fd68cd13
...@@ -415,6 +415,26 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -415,6 +415,26 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name) loaded_params.add(name)
# Validate that weights were loaded for each expected MTP layer.
loaded_layers: set[int] = set()
for param_name in loaded_params:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
if spec_layer is not None:
loaded_layers.add(spec_layer)
for layer_idx in range(
self.model.mtp_start_layer_idx,
self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
):
if layer_idx not in loaded_layers:
raise ValueError(
f"MTP speculative decoding layer {layer_idx} weights "
f"missing from checkpoint. The checkpoint may have "
f"been quantized without including the MTP layers. "
f"Use a checkpoint that includes MTP layer weights, "
f"or disable speculative decoding."
)
return loaded_params return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment