Unverified Commit e84786aa authored by Stas Bekman, committed by GitHub

consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys    => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected

* move public doc of private attributes to private comment

Parent: 49759c0c
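For context, these class attributes hold regex patterns that `from_pretrained` matches against state-dict key names to decide which missing or unexpected keys should not be reported to the user. A minimal sketch of that filtering step, assuming a hypothetical `filter_ignored` helper (the real logic lives in `modeling_utils.py` and is more involved):

import re

def filter_ignored(keys, patterns):
    # Keep only the keys that match none of the ignore patterns, so the
    # load-time warning mentions genuinely surprising keys only.
    if patterns is None:
        return keys
    return [k for k in keys if not any(re.search(p, k) for p in patterns)]

# Roughly what happens inside from_pretrained:
# missing_keys = filter_ignored(missing_keys, cls._keys_to_ignore_on_load_missing)
# unexpected_keys = filter_ignored(unexpected_keys, cls._keys_to_ignore_on_load_unexpected)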
@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
     """
     model_type = "mbart"
     config_class = MBartConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
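MBart's sinusoidal position embeddings are deterministic, so they can be dropped at save time and rebuilt at load time; that is why the same keys appear in both `_keys_to_ignore_on_save` and `_keys_to_ignore_on_load_missing`. A sketch of the save-side filtering, assuming exact-match key names as used here (simplified from what `save_pretrained` does, not the verbatim implementation):

import torch

state_dict = model.state_dict()
ignore = getattr(model, "_keys_to_ignore_on_save", None) or []
# Drop weights that can be re-derived at load time (e.g. static sinusoidal
# position embeddings) to keep the checkpoint smaller.
state_dict = {k: v for k, v in state_dict.items() if k not in ignore}
torch.save(state_dict, "pytorch_model.bin")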
@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
     load_tf_weights = load_tf_weights_in_mobilebert
     base_model_prefix = "mobilebert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1054,7 +1054,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
 class MobileBertForMaskedLM(MobileBertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1350,7 +1350,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
 )
 class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1545,7 +1545,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
 )
 class MobileBertForTokenClassification(MobileBertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1030,7 +1030,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
 class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1297,7 +1297,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
 )
 class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1529,7 +1529,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
 )
 class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
...
@@ -42,12 +42,12 @@ class MT5Model(T5Model):
     """
     model_type = "mt5"
     config_class = MT5Config
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
     ]
@@ -71,13 +71,13 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
     model_type = "mt5"
     config_class = MT5Config
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"lm_head\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
     ]
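T5 and MT5 tie `encoder.embed_tokens`, `decoder.embed_tokens` (and, for generation, `lm_head`) to one shared embedding matrix, so those keys are redundant in a checkpoint: ignoring them on save avoids duplication, and ignoring them on load avoids a spurious missing-key warning when the weights are re-tied. A toy illustration of why a tied parameter shows up under several state-dict keys (hypothetical module, not the actual T5 code):

import torch.nn as nn

class TinySharedEmbed(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(100, 16)
        # The same module object registered under two more names, as T5 does.
        self.encoder_embed = self.shared
        self.decoder_embed = self.shared

print(list(TinySharedEmbed().state_dict().keys()))
# ['shared.weight', 'encoder_embed.weight', 'decoder_embed.weight']
# Three names, one tensor: only one copy needs to live in the checkpoint.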
@@ -279,7 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     config_class = OpenAIGPTConfig
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
...
@@ -46,14 +46,14 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
     """
     # All the code is in src/transformers/models/bart/modeling_bart.py
     config_class = PegasusConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"encoder\.version",
         r"decoder\.version",
         "model.encoder.embed_positions",
         "model.decoder.embed_positions",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
 @add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
 class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
...
@@ -216,7 +216,7 @@ class RagPreTrainedModel(PreTrainedModel):
     """
     config_class = RagConfig
     base_model_prefix = "rag"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     @classmethod
     def from_pretrained_question_encoder_generator(
...
@@ -576,7 +576,7 @@ class RobertaModel(RobertaPreTrainedModel):
     """

-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
     def __init__(self, config, add_pooling_layer=True):
@@ -711,8 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
     """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
 )
 class RobertaForCausalLM(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -829,8 +829,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class RobertaForMaskedLM(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -948,7 +948,7 @@ class RobertaLMHead(nn.Module):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForSequenceClassification(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -1031,7 +1031,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForMultipleChoice(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -1123,8 +1123,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForTokenClassification(RobertaPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -1233,8 +1233,8 @@ class RobertaClassificationHead(nn.Module):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForQuestionAnswering(RobertaPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
...
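In day-to-day use these patterns are what keep the usual base-checkpoint-to-task-head workflow quiet. For example, loading the public `roberta-base` checkpoint into a token-classification model:

from transformers import RobertaForTokenClassification

# The checkpoint contains pooler weights this model has no use for (covered by
# _keys_to_ignore_on_load_unexpected), and position_ids is a buffer that can be
# rebuilt (covered by _keys_to_ignore_on_load_missing); only the freshly
# initialized classification head is still reported as newly initialized.
model = RobertaForTokenClassification.from_pretrained("roberta-base")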
@@ -765,7 +765,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -877,7 +877,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
 )
 class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1084,7 +1084,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
 )
 class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1171,7 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
 )
 class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
...
@@ -428,7 +428,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
     config_class = SqueezeBertConfig
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -642,7 +642,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
 @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
 class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
-    authorized_missing_keys = [r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1086,7 +1086,7 @@ T5_INPUTS_DOCSTRING = r"""
     T5_START_DOCSTRING,
 )
 class T5Model(T5PreTrainedModel):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
@@ -1258,7 +1258,7 @@ class T5Model(T5PreTrainedModel):
 @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class T5ForConditionalGeneration(T5PreTrainedModel):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"lm_head\.weight",
...
@@ -399,7 +399,7 @@ XLM_INPUTS_DOCSTRING = r"""
     XLM_START_DOCSTRING,
 )
 class XLMModel(XLMPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -540,7 +540,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
     config_class = {{cookiecutter.camelcase_modelname}}Config
     load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
     base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
...
@@ -135,17 +135,17 @@ class ModelTesterMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)

-    def test_save_load_keys_to_never_save(self):
+    def test_save_load__keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
             model = model_class(config)
-            keys_to_never_save = getattr(model, "keys_to_never_save", None)
-            if keys_to_never_save is None:
+            _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
+            if _keys_to_ignore_on_save is None:
                 continue

             # check the keys are in the original state_dict
-            for k in keys_to_never_save:
+            for k in _keys_to_ignore_on_save:
                 self.assertIn(k, model.state_dict())

             # check that certain keys didn't get saved with the model
@@ -153,7 +153,7 @@ class ModelTesterMixin:
                 model.save_pretrained(tmpdirname)
                 output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
                 state_dict_saved = torch.load(output_model_file)
-                for k in keys_to_never_save:
+                for k in _keys_to_ignore_on_save:
                     self.assertNotIn(k, state_dict_saved)

     def test_initialization(self):
...
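Outside the test harness, the round trip this test asserts looks roughly like the sketch below, using MBartForConditionalGeneration as one of the models that sets `_keys_to_ignore_on_save` (checkpoint name and weights filename as in this era of the library):

import os
import tempfile

import torch
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    saved = torch.load(os.path.join(tmpdirname, "pytorch_model.bin"))
    for k in model._keys_to_ignore_on_save:
        assert k in model.state_dict()  # present on the live model...
        assert k not in saved           # ...but filtered out of the checkpoint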
@@ -60,7 +60,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MarianMTModel,) if is_torch_available() else ()

-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

     def setUp(self):
         self.model_tester = ModelTester(self)
...
@@ -47,7 +47,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()

-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

     def setUp(self):
         self.model_tester = ModelTester(self)
...
@@ -43,7 +43,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()

-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

     def setUp(self):
         self.model_tester = ModelTester(self)
...
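These `SelectiveCommonTest` classes cherry-pick a single test from `ModelTesterMixin` instead of inheriting the whole mixin, which is why each rename of the test method has to be mirrored here. Assigning the function to a class attribute is all unittest needs to discover it; a generic sketch of the pattern, with illustrative names:

import unittest

class SharedChecks:
    def test_value_is_set(self):
        self.assertIsNotNone(self.value)

class SelectiveTest(unittest.TestCase):
    value = 42
    # Borrow exactly one shared test rather than subclassing SharedChecks.
    test_value_is_set = SharedChecks.test_value_is_set

if __name__ == "__main__":
    unittest.main()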