"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "c68c7b403dce632dbbbb6d2482ea86fe7bf53d51"
Unverified commit d3ade61e, authored by wuyaoxuehun and committed by GitHub
Browse files

[Model] fix glm4_moe_mtp load weights with GLM-4.6 checkpoint. (#27597)


Signed-off-by: wuao.scotty <wuao.scotty@bytedance.com>
Co-authored-by: wuao.scotty <wuao.scotty@bytedance.com>
parent 1761dea1
...@@ -256,11 +256,18 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts): ...@@ -256,11 +256,18 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
spec_layer = self.model.mtp_start_layer_idx
for name, loaded_weight in weights: for name, loaded_weight in weights:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if name == "lm_head.weight":
if spec_layer is None: name = f"model.layers.{spec_layer}.shard_head.head.weight"
continue elif name == "model.embed_tokens.weight":
name = self._rewrite_spec_layer_name(spec_layer, name) # This name is same with local model, rewriting is not needed.
pass
else:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
if weight_name not in name: if weight_name not in name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment