Unverified Commit e84786aa authored by Stas Bekman, committed by GitHub

consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys    => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected

* move public doc of private attributes to private comment
parent 49759c0c
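
In short: the public class attributes become private, and behavior is unchanged. A minimal sketch of what a model subclass looks like after this commit (the subclass name is illustrative, not from this diff):

```python
from transformers import BertPreTrainedModel


class MyBertHead(BertPreTrainedModel):  # hypothetical example class
    # formerly the public attributes `authorized_unexpected_keys`,
    # `authorized_missing_keys`, and `keys_to_never_save`
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    _keys_to_ignore_on_save = None
```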
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
-            elif tf_model.authorized_missing_keys is not None:
+            elif tf_model._keys_to_ignore_on_load_missing is not None:
                 # authorized missing keys don't have to be loaded
-                if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                     continue

             raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     unexpected_keys = list(all_pytorch_weights)

-    if tf_model.authorized_missing_keys is not None:
-        for pat in tf_model.authorized_missing_keys:
+    if tf_model._keys_to_ignore_on_load_missing is not None:
+        for pat in tf_model._keys_to_ignore_on_load_missing:
             missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-    if tf_model.authorized_unexpected_keys is not None:
-        for pat in tf_model.authorized_unexpected_keys:
+    if tf_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

     if len(unexpected_keys) > 0:
......
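
These two hunks are the heart of the change on the TF side: each entry in the (now private) lists is a regular expression, and any missing or unexpected weight name matching one of them is filtered out before the loader warns. A standalone sketch of that filtering (key names are illustrative):

```python
import re

_keys_to_ignore_on_load_missing = [r"position_ids"]

missing_keys = ["bert.embeddings.position_ids", "qa_outputs.weight"]
# same per-pattern loop as in the diff above
for pat in _keys_to_ignore_on_load_missing:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ['qa_outputs.weight'] -- position_ids is dropped silently
```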
@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
-          from the model when loading the model weights (and avoid unnecessary warnings).
-        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
-          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """

     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
+    # a list of re patterns of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re patterns of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None

     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -742,12 +742,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             model(model.dummy_inputs, training=False)  # Make sure restore ops are run

-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         if len(unexpected_keys) > 0:
......
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
-          when loading the model (and avoid unnecessary warnings).
-        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving the
-          model (useful for keys that aren't trained, but which are deterministic)
     """

     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
-    keys_to_never_save = None
+    # a list of re patterns of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re patterns of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of tensor names to ignore when saving the model (useful for keys that aren't
+    # trained, but which are deterministic)
+    _keys_to_ignore_on_save = None

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         state_dict = model_to_save.state_dict()

         # Handle the case where some state_dict keys shouldn't be saved
-        if self.keys_to_never_save is not None:
-            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+        if self._keys_to_ignore_on_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}

         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         if len(unexpected_keys) > 0:
......
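
On the save side, `_keys_to_ignore_on_save` (ex `keys_to_never_save`) uses exact key membership rather than regular expressions. A small sketch of the same dict-comprehension filter outside the library (key names and tensors are illustrative):

```python
import torch

_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight"]

state_dict = {
    "model.encoder.embed_positions.weight": torch.zeros(4, 8),  # deterministic, not worth saving
    "model.encoder.layers.0.fc1.weight": torch.zeros(8, 8),
}
# exact membership test, unlike the re.search filtering used at load time
state_dict = {k: v for k, v in state_dict.items() if k not in _keys_to_ignore_on_save}
print(list(state_dict))  # ['model.encoder.layers.0.fc1.weight']
```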
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):

     config_class = AlbertConfig
     base_model_prefix = "albert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -851,7 +851,7 @@ class AlbertSOPHead(nn.Module):
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1021,7 +1021,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
 )
 class AlbertForTokenClassification(AlbertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1110,7 +1110,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
 )
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
......
@@ -843,7 +843,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
 )
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
 )
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
......
@@ -946,7 +946,7 @@ class BartModel(PretrainedBartModel):
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

     def __init__(self, config: BartConfig):
         super().__init__(config)
......
@@ -1020,10 +1020,10 @@ class TFBartModel(TFPretrainedBartModel):
 )
 class TFBartForConditionalGeneration(TFPretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
     ]
-    authorized_unexpected_keys = [
+    _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
     ]
......
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -969,8 +969,8 @@ class BertForPreTraining(BertPreTrainedModel):
 )
 class BertLMHeadModel(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1087,8 +1087,8 @@ class BertLMHeadModel(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1469,7 +1469,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
 )
 class BertForTokenClassification(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1560,7 +1560,7 @@ class BertForTokenClassification(BertPreTrainedModel):
 )
 class BertForQuestionAnswering(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
......
@@ -938,8 +938,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1023,8 +1023,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1416,8 +1416,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1502,8 +1502,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
......
@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):

     config_class = BertGenerationConfig
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
......
@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):

     config_class = DebertaConfig
     base_model_prefix = "deberta"
-    authorized_missing_keys = ["position_ids"]
+    _keys_to_ignore_on_load_missing = ["position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
......
@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "ctx_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.ctx_encoder.init_weights()
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "span_predictor"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.span_predictor.encoder.init_weights()
......
@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
-    authorized_missing_keys = [r"position_ids"]
-    authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
......
@@ -1005,11 +1005,11 @@ class FSMTModel(PretrainedFSMTModel):
 )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
......
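
FSMT (and Marian below) list their positional-embedding weights in both `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save`: as the docstring above puts it, these keys aren't trained but are deterministic, so they can be dropped from the checkpoint and rebuilt at load time. A hedged sketch of such a deterministic table (the exact formula these models use may differ):

```python
import math
import torch

def sinusoidal_positions(n_pos: int, dim: int) -> torch.Tensor:
    """Deterministic sin/cos position table (dim assumed even); identical on
    every call, so nothing is lost by skipping it when saving a checkpoint."""
    pos = torch.arange(n_pos, dtype=torch.float).unsqueeze(1)
    inv_freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
    table = torch.zeros(n_pos, dim)
    table[:, 0::2] = torch.sin(pos * inv_freq)
    table[:, 1::2] = torch.cos(pos * inv_freq)
    return table
```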
@@ -780,7 +780,7 @@ class GPT2Model(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1097,7 +1097,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
......
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):

     config_class = LayoutLMConfig
     base_model_prefix = "layoutlm"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
......
@@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel):

     config_class = LongformerConfig
     base_model_prefix = "longformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1621,7 +1621,7 @@ class LongformerModel(LongformerPreTrainedModel):
 @add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
 class LongformerForMaskedLM(LongformerPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1718,7 +1718,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
 )
 class LongformerForSequenceClassification(LongformerPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1827,7 +1827,7 @@ class LongformerClassificationHead(nn.Module):
 )
 class LongformerForQuestionAnswering(LongformerPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1961,7 +1961,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
 )
 class LongformerForTokenClassification(LongformerPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
......
@@ -1961,7 +1961,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
 )
 class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2048,7 +2048,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
 )
 class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2199,7 +2199,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
 )
 class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2443,7 +2443,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
 )
 class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
......
@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):
     """

     config_class = MarianConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
......
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)

 @add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
 class TFMarianMTModel(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
     ]
......
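
The end-user effect of all these per-model lists is simply quieter, more accurate loading. A usage sketch (checkpoint name illustrative; requires network access):

```python
from transformers import BertForQuestionAnswering

# The base BERT checkpoint ships a pooler that this head never uses; with
# _keys_to_ignore_on_load_unexpected = [r"pooler"] set on the class, loading
# no longer emits an "unexpected keys" warning for it.
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
```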