Unverified commit 44f64132, authored by Suraj Patil, committed by GitHub

remove final_logits_bias (#10606)

parent 6f52fce6
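
For context: `final_logits_bias` was registered in `__init__` as an all-zeros buffer and was only ever padded with more zeros when the vocabulary was resized, so it never changed the logits. The diff below removes the buffer, its resize helper, and the addition in the forward pass. A minimal before/after sketch (standalone toy code, not the actual M2M100 class; `d_model` and `vocab_size` are illustrative values):

# Standalone sketch of the change, not the transformers source itself.
# d_model and vocab_size are toy values chosen for illustration.
import torch
import torch.nn as nn

d_model, vocab_size = 16, 32
decoder_hidden = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

# Before this commit: a zero-initialized buffer was added to the LM head output.
final_logits_bias = torch.zeros((1, vocab_size))
logits_old = lm_head(decoder_hidden) + final_logits_bias

# After this commit: the bias is dropped. Because the buffer held only zeros,
# the logits are numerically identical.
logits_new = lm_head(decoder_hidden)
assert torch.equal(logits_old, logits_new)

Note the caveat: since the buffer was a registered (saved) buffer, the outputs only stay identical for checkpoints that never stored a nonzero bias, which is the case here given that nothing in this class ever wrote nonzero values into it.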
...@@ -1153,7 +1153,6 @@ class M2M100Model(M2M100PreTrainedModel): ...@@ -1153,7 +1153,6 @@ class M2M100Model(M2M100PreTrainedModel):
class M2M100ForConditionalGeneration(M2M100PreTrainedModel): class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
base_model_prefix = "model" base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [ _keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder\.version", r"encoder\.version",
r"decoder\.version", r"decoder\.version",
r"lm_head\.weight", r"lm_head\.weight",
...@@ -1168,7 +1167,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1168,7 +1167,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def __init__(self, config: M2M100Config): def __init__(self, config: M2M100Config):
super().__init__(config) super().__init__(config)
self.model = M2M100Model(config) self.model = M2M100Model(config)
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
self.init_weights() self.init_weights()
...@@ -1181,18 +1179,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1181,18 +1179,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens) new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
old_num_tokens = self.final_logits_bias.shape[-1]
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head return self.lm_head
...@@ -1266,7 +1254,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ...@@ -1266,7 +1254,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
) )
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias lm_logits = self.lm_head(outputs[0])
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None:
......