Unverified Commit a3339d8c authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Bug: Fix weight loader error when LM head weights are tied (#3766)

parent 14d90617
...@@ -458,6 +458,8 @@ class LlamaForCausalLM(nn.Module): ...@@ -458,6 +458,8 @@ class LlamaForCausalLM(nn.Module):
continue continue
if name.startswith("model.vision_tower") and name not in params_dict: if name.startswith("model.vision_tower") and name not in params_dict:
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
# Handle FP8 kv-scale remapping # Handle FP8 kv-scale remapping
if "scale" in name: if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict) name = maybe_remap_kv_scale_name(name, params_dict)
......
...@@ -339,6 +339,8 @@ class MiniCPMForCausalLM(nn.Module): ...@@ -339,6 +339,8 @@ class MiniCPMForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
......
...@@ -603,6 +603,8 @@ class MiniCPM3ForCausalLM(nn.Module): ...@@ -603,6 +603,8 @@ class MiniCPM3ForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
......
...@@ -325,6 +325,8 @@ class OlmoForCausalLM(nn.Module): ...@@ -325,6 +325,8 @@ class OlmoForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -433,6 +433,8 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -433,6 +433,8 @@ class Phi3SmallForCausalLM(nn.Module):
continue continue
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
......
...@@ -377,6 +377,8 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -377,6 +377,8 @@ class Qwen2ForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if name.startswith("model.vision_tower") and name not in params_dict: if name.startswith("model.vision_tower") and name not in params_dict:
continue continue
......
...@@ -586,6 +586,8 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -586,6 +586,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
......
...@@ -486,6 +486,8 @@ class TorchNativeLlamaForCausalLM(nn.Module): ...@@ -486,6 +486,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
continue continue
if name.startswith("model.vision_tower") and name not in params_dict: if name.startswith("model.vision_tower") and name not in params_dict:
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment