Unverified Commit e84786aa authored by Stas Bekman, committed by GitHub

consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys    => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected

* move public doc of private attributes to private comment
parent 49759c0c
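
Note on what the renamed attributes do: the two `_keys_to_ignore_on_load_*` lists hold regex patterns that `PreTrainedModel.from_pretrained` uses to silence the missing/unexpected-key warnings, while `_keys_to_ignore_on_save` lists state-dict keys that `save_pretrained` skips when serializing. Below is a minimal sketch of the load-side filtering, assuming the `re.search` semantics used in `modeling_utils.py` around this commit; `filter_ignored_keys` is a hypothetical helper for illustration, not part of the library API:

```python
import re
from typing import List, Optional


def filter_ignored_keys(keys: List[str], patterns: Optional[List[str]]) -> List[str]:
    """Drop every key that matches one of the ignore patterns (regex search).

    Hypothetical helper mirroring (not copying) the filtering that
    from_pretrained applies to its missing/unexpected key lists.
    """
    if patterns is None:
        return keys
    for pat in patterns:
        keys = [k for k in keys if re.search(pat, k) is None]
    return keys


# e.g. the RoBERTa heads below set _keys_to_ignore_on_load_unexpected = [r"pooler"],
# so stale pooler weights left over in a base checkpoint raise no warning:
unexpected_keys = ["roberta.pooler.dense.weight", "roberta.pooler.dense.bias"]
assert filter_ignored_keys(unexpected_keys, [r"pooler"]) == []
```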
@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
     """
     model_type = "mbart"
     config_class = MBartConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
     load_tf_weights = load_tf_weights_in_mobilebert
     base_model_prefix = "mobilebert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1054,7 +1054,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):

 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
 class MobileBertForMaskedLM(MobileBertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1350,7 +1350,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
 )
 class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1545,7 +1545,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
 )
 class MobileBertForTokenClassification(MobileBertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1030,7 +1030,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):

 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
 class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1297,7 +1297,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequ
 )
 class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1529,7 +1529,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
 )
 class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -42,12 +42,12 @@ class MT5Model(T5Model):
     """
     model_type = "mt5"
     config_class = MT5Config
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
     ]
@@ -71,13 +71,13 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
     model_type = "mt5"
     config_class = MT5Config
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"lm_head\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
     ]
@@ -279,7 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     config_class = OpenAIGPTConfig
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -46,14 +46,14 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
     """
     # All the code is in src/transformers/models/bart/modeling_bart.py
     config_class = PegasusConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"encoder\.version",
         r"decoder\.version",
         "model.encoder.embed_positions",
         "model.decoder.embed_positions",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)

 @add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
 class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
@@ -216,7 +216,7 @@ class RagPreTrainedModel(PreTrainedModel):
     """
     config_class = RagConfig
     base_model_prefix = "rag"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     @classmethod
     def from_pretrained_question_encoder_generator(
@@ -576,7 +576,7 @@ class RobertaModel(RobertaPreTrainedModel):
     """

-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
     def __init__(self, config, add_pooling_layer=True):
@@ -711,8 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
     """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
 )
 class RobertaForCausalLM(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -829,8 +829,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):

 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class RobertaForMaskedLM(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -948,7 +948,7 @@ class RobertaLMHead(nn.Module):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForSequenceClassification(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -1031,7 +1031,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForMultipleChoice(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -1123,8 +1123,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForTokenClassification(RobertaPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -1233,8 +1233,8 @@ class RobertaClassificationHead(nn.Module):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForQuestionAnswering(RobertaPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -765,7 +765,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):

 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -877,7 +877,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
 )
 class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1084,7 +1084,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
 )
 class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1171,7 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
 )
 class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -428,7 +428,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
     config_class = SqueezeBertConfig
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -642,7 +642,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):

 @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
 class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
-    authorized_missing_keys = [r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1086,7 +1086,7 @@ T5_INPUTS_DOCSTRING = r"""
     T5_START_DOCSTRING,
 )
 class T5Model(T5PreTrainedModel):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
@@ -1258,7 +1258,7 @@ class T5Model(T5PreTrainedModel):

 @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class T5ForConditionalGeneration(T5PreTrainedModel):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"lm_head\.weight",
@@ -399,7 +399,7 @@ XLM_INPUTS_DOCSTRING = r"""
     XLM_START_DOCSTRING,
 )
 class XLMModel(XLMPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def __init__(self, config):
         super().__init__(config)
@@ -540,7 +540,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
     config_class = {{cookiecutter.camelcase_modelname}}Config
     load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
     base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -135,17 +135,17 @@ class ModelTesterMixin:
                 max_diff = np.amax(np.abs(out_1 - out_2))
                 self.assertLessEqual(max_diff, 1e-5)

-    def test_save_load_keys_to_never_save(self):
+    def test_save_load__keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
             model = model_class(config)
-            keys_to_never_save = getattr(model, "keys_to_never_save", None)
-            if keys_to_never_save is None:
+            _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
+            if _keys_to_ignore_on_save is None:
                 continue

             # check the keys are in the original state_dict
-            for k in keys_to_never_save:
+            for k in _keys_to_ignore_on_save:
                 self.assertIn(k, model.state_dict())

             # check that certain keys didn't get saved with the model
@@ -153,7 +153,7 @@ class ModelTesterMixin:
                 model.save_pretrained(tmpdirname)
                 output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
                 state_dict_saved = torch.load(output_model_file)
-                for k in keys_to_never_save:
+                for k in _keys_to_ignore_on_save:
                     self.assertNotIn(k, state_dict_saved)

     def test_initialization(self):
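
The renamed test above pins down the save-side contract: every key listed in `_keys_to_ignore_on_save` must exist in `model.state_dict()` but must be absent from the serialized checkpoint. Below is a minimal sketch of that behavior, assuming the exact-match key filtering `save_pretrained` performed at the time; `state_dict_for_saving` is a hypothetical helper, not the library's API:

```python
import os
import tempfile

import torch
from torch import nn


def state_dict_for_saving(model: nn.Module, keys_to_ignore_on_save):
    """Return the state dict with the to-be-ignored keys dropped (exact match)."""
    state_dict = model.state_dict()
    if keys_to_ignore_on_save is not None:
        state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_ignore_on_save}
    return state_dict


model = nn.Linear(4, 4)  # stand-in for a PreTrainedModel
saved = state_dict_for_saving(model, ["bias"])
assert "bias" in model.state_dict() and "bias" not in saved

with tempfile.TemporaryDirectory() as tmpdirname:
    path = os.path.join(tmpdirname, "pytorch_model.bin")
    torch.save(saved, path)
    assert "bias" not in torch.load(path)  # mirrors assertNotIn in the test
```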
@@ -60,7 +60,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MarianMTModel,) if is_torch_available() else ()

-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

     def setUp(self):
         self.model_tester = ModelTester(self)
@@ -47,7 +47,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()

-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

     def setUp(self):
         self.model_tester = ModelTester(self)
@@ -43,7 +43,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()

-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

     def setUp(self):
         self.model_tester = ModelTester(self)