Commit 7bf17aa2 authored by laibao's avatar laibao
Browse files

[BUGFIX] 修复 Step3p5 MTP 参数加载与 EAGLE lm_head 共享逻辑

fix:

- 修复 Step3p5 MTP 在加载 checkpoint 时对可选标量参数的识别逻辑,将 q/k/v zero_point 纳入 optional 参数集合,避免参数校验与加载不一致。

revert:

- 回退 EAGLE 中针对 MTP shared_head.head 强制复用 target lm_head 的逻辑,避免与当前 Step3p5 MTP 权重结构产生冲突。

目的:

- 降低 Step3p5 MTP 在权重加载阶段的兼容性问题,减少由于 lm_head 共享路径不一致导致的异常行为,方便后续排查和协作。
parent 824dde97
......@@ -275,9 +275,9 @@ class Step3p5MTP(nn.Module):
optional_params = {
name
for name, param in params_dict.items()
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale",".k_zero_point", ".v_zero_point", ".q_zero_point"))
and getattr(param, "numel", lambda: 0)() == 1
and getattr(param, "requires_grad", False) is False
#and getattr(param, "requires_grad", False) is False
}
params_need_to_load -= optional_params
if params_need_to_load != loaded_params:
......
......@@ -1281,24 +1281,6 @@ class SpecDecodeBaseProposer:
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
# MTP models call compute_logits via shared_head.head (a
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
# If the checkpoint omits a copy of the lm_head weights at the
# MTP layer path, shared_head.head stays uninitialised and
# produces NaN logits. Always share it explicitly.
inner = getattr(self.model, "model", None)
layers = getattr(inner, "layers", None) if inner else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_language_model.lm_head
logger.info(
"Shared target model lm_head with MTP shared_head.head."
)
@torch.inference_mode()
def dummy_run(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment