Unverified Commit c6d5e565 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix naming (#10095)

parent 4ed76377
...@@ -1105,7 +1105,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel ...@@ -1105,7 +1105,6 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
_keys_to_ignore_on_load_unexpected = [ _keys_to_ignore_on_load_unexpected = [
r"pooler", r"pooler",
r"seq_relationship___cls", r"seq_relationship___cls",
r"predictions___cls",
r"cls.seq_relationship", r"cls.seq_relationship",
] ]
...@@ -1113,10 +1112,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel ...@@ -1113,10 +1112,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
self.mlm = TFMobileBertMLMHead(config, name="mlm___cls") self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
def get_lm_head(self): def get_lm_head(self):
return self.mlm.predictions return self.predictions.predictions
def get_prefix_bias_name(self): def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
...@@ -1179,7 +1178,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel ...@@ -1179,7 +1178,7 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
training=inputs["training"], training=inputs["training"],
) )
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=inputs["training"]) prediction_scores = self.predictions(sequence_output, training=inputs["training"])
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], prediction_scores)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment