Unverified Commit 07c11cf4 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix lm_head weights tying with lora for llama (#9227)

parent f3a507f1
...@@ -443,7 +443,7 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -443,7 +443,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super().__init__(num_embeddings, embedding_dim, params_dtype, super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config, org_num_embeddings, padding_size, quant_config,
prefix) prefix)
self.quant_config = quant_config
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition, torch.empty(self.num_embeddings_per_partition,
...@@ -455,6 +455,15 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -455,6 +455,15 @@ class ParallelLMHead(VocabParallelEmbedding):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
"""Tie the weights with word embeddings."""
# GGUF quantized embed_tokens.
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
self.weight = embed_tokens.weight
return self
def forward(self, input_): def forward(self, input_):
del input_ del input_
raise RuntimeError("LMHead's weights should be used in the sampler.") raise RuntimeError("LMHead's weights should be used in the sampler.")
...@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config=quant_config, quant_config=quant_config,
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.lm_head.tie_weights(
self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment