Unverified Commit 630dd9e0 authored by Travis Johnson's avatar Travis Johnson Committed by GitHub
Browse files

[Bugfix][Model] Skip loading lm_head weights if using tie_word_embeddings (#6758)


Signed-off-by: default avatarTravis Johnson <tsjohnso@us.ibm.com>
parent 23993a79
...@@ -998,6 +998,13 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision): ...@@ -998,6 +998,13 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
use_default_weight_loading = False use_default_weight_loading = False
if "vqmodel" in name: if "vqmodel" in name:
if self.model.vqmodel is not None: if self.model.vqmodel is not None:
......
...@@ -469,6 +469,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -469,6 +469,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name): if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization # Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name] param = params_dict[scale_name]
......
...@@ -514,7 +514,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -514,7 +514,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -343,6 +343,11 @@ class OlmoForCausalLM(nn.Module): ...@@ -343,6 +343,11 @@ class OlmoForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment