Unverified Commit ada268fd authored by Kyungmin Lee, committed by GitHub

fix: EXAONE when using tie_word_embeddings (#5759)

parent cfe48c59
@@ -307,9 +307,14 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(
             config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("lm_head", prefix),
+            )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
...
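The patch routes the LM head through the input embedding when the checkpoint sets tie_word_embeddings, instead of always allocating a separate ParallelLMHead. Below is a minimal, self-contained PyTorch sketch of that pattern, not the SGLang implementation; the class name TinyTiedLM, its parameters, and the sizes used at the end are made up for illustration, and a real model runs the hidden states through transformer layers before the output projection.

import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    # Toy model (hypothetical, not SGLang code): illustrates tie_word_embeddings.
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool = True):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)  # input token embedding
        if tie_word_embeddings:
            # Reuse the embedding as the LM head, mirroring
            # `self.lm_head = self.transformer.wte` in the patch above.
            self.lm_head = self.wte
        else:
            # Untied checkpoints ship a separate output projection.
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.wte(input_ids)  # [batch, seq, hidden]; real models apply transformer layers here
        weight = self.lm_head.weight  # same tensor object as self.wte.weight when tied
        return hidden @ weight.t()    # logits: [batch, seq, vocab]


model = TinyTiedLM(vocab_size=102400, hidden_size=4096, tie_word_embeddings=True)
# When tied, the head and the embedding share one parameter, so loading the
# embedding weight is enough and no separate lm_head weight is expected in the checkpoint.
assert model.lm_head.weight.data_ptr() == model.wte.weight.data_ptr()

Sharing one tensor also avoids duplicating a vocab_size x hidden_size matrix in memory, which is the usual motivation for tied checkpoints.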