Commit 84b9fe55 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev-wm' into 'v0.15.1-dev'

[fix]修复GLM mtp精度问题

See merge request dcutoolkit/deeplearing/vllm!518
parents dfb597c8 44d4976d
...@@ -1267,6 +1267,24 @@ class SpecDecodeBaseProposer: ...@@ -1267,6 +1267,24 @@ class SpecDecodeBaseProposer:
del self.model.lm_head del self.model.lm_head
self.model.lm_head = target_language_model.lm_head self.model.lm_head = target_language_model.lm_head
# MTP models call compute_logits via shared_head.head (a
# ParallelLMHead inside each MTP layer), not self.model.lm_head.
# If the checkpoint omits a copy of the lm_head weights at the
# MTP layer path, shared_head.head stays uninitialised and
# produces NaN logits. Always share it explicitly.
inner = getattr(self.model, "model", None)
layers = getattr(inner, "layers", None) if inner else None
if layers is not None:
items = layers.values() if isinstance(layers, nn.ModuleDict) else layers
for layer in items:
sh = getattr(layer, "shared_head", None)
if sh is not None and hasattr(sh, "head"):
del sh.head
sh.head = target_language_model.lm_head
logger.info(
"Shared target model lm_head with MTP shared_head.head."
)
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment