Unverified Commit 1471857f authored by Suraj Patil, committed by GitHub

update decoder_vocab_size when resizing embeds (#16700)

parent 5e686757
@@ -1280,11 +1280,9 @@ class MarianMTModel(MarianPreTrainedModel):
         super().__init__(config)
         self.model = MarianModel(config)
 
-        self.target_vocab_size = (
-            config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
-        )
-        self.register_buffer("final_logits_bias", torch.zeros((1, self.target_vocab_size)))
-        self.lm_head = nn.Linear(config.d_model, self.target_vocab_size, bias=False)
+        target_vocab_size = config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
+        self.register_buffer("final_logits_bias", torch.zeros((1, target_vocab_size)))
+        self.lm_head = nn.Linear(config.d_model, target_vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1306,6 +1304,10 @@ class MarianMTModel(MarianPreTrainedModel):
         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
         self.set_input_embeddings(new_embeddings)
 
+        # update config.decoder_vocab_size if embeddings are tied
+        if self.config.share_encoder_decoder_embeddings:
+            self.config.decoder_vocab_size = new_num_tokens
+
         # if word embeddings are not tied, make sure that lm head is resized as well
         if (
             self.config.share_encoder_decoder_embeddings
@@ -1451,7 +1453,7 @@ class MarianMTModel(MarianPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
            loss_fct = CrossEntropyLoss()
-            masked_lm_loss = loss_fct(lm_logits.view(-1, self.target_vocab_size), labels.view(-1))
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
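A minimal usage sketch of the behavior this commit changes, assuming a standard Marian checkpoint with tied encoder/decoder embeddings (the checkpoint name below is only an example, not taken from the commit): after resize_token_embeddings, config.decoder_vocab_size now tracks the new embedding size instead of going stale.

# Sketch only: the checkpoint name is an illustrative example.
from transformers import MarianMTModel

model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")  # shared encoder/decoder embeddings

print(model.config.decoder_vocab_size)                 # matches config.vocab_size for tied embeddings
model.resize_token_embeddings(model.config.vocab_size + 8)
print(model.config.decoder_vocab_size)                 # with this fix, reflects the resized vocabulary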