Unverified Commit 561b77a0 authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

[Bugfix] Fix the lm_head in gpt_bigcode in lora mode (#6357)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Signed-off-by: default avatarMax de Bayser <maxdebayser@gmail.com>
parent abd4030d
......@@ -272,12 +272,6 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}
# LoRA specific attributes
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -330,8 +324,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = None
if self.config.tie_word_embeddings:
skip_prefixes = ["lm_head."]
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]),
skip_prefixes=skip_prefixes,
)
return loader.load_weights(weights)
\ No newline at end of file
return loader.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment