"EDK2/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "1d034f0a24ed466d5942689540fcbdc7ec262d05"
Unverified commit e84786aa authored by Stas Bekman, committed by GitHub

consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys    => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected

* move public doc of private attributes to private comment
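
For reference, this is how a model class sets the renamed attributes after this change; a minimal sketch, assuming a hypothetical subclass (the class name and the particular patterns below are illustrative, not taken from this diff):

    from transformers import PreTrainedModel

    class MyModel(PreTrainedModel):  # hypothetical subclass, for illustration only
        base_model_prefix = "my_model"
        # each entry is a regex matched against state_dict key names with re.search
        _keys_to_ignore_on_load_missing = [r"position_ids"]     # was: authorized_missing_keys
        _keys_to_ignore_on_load_unexpected = [r"pooler"]        # was: authorized_unexpected_keys
        _keys_to_ignore_on_save = None                          # was: keys_to_never_save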
parent 49759c0c
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
-            elif tf_model.authorized_missing_keys is not None:
+            elif tf_model._keys_to_ignore_on_load_missing is not None:
                 # authorized missing keys don't have to be loaded
-                if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                     continue
             raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     unexpected_keys = list(all_pytorch_weights)

-    if tf_model.authorized_missing_keys is not None:
-        for pat in tf_model.authorized_missing_keys:
+    if tf_model._keys_to_ignore_on_load_missing is not None:
+        for pat in tf_model._keys_to_ignore_on_load_missing:
             missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
-    if tf_model.authorized_unexpected_keys is not None:
-        for pat in tf_model.authorized_unexpected_keys:
+    if tf_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

     if len(unexpected_keys) > 0:
...
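
Every hunk in this commit renames the same filtering pattern: each list entry is a regular expression, and a key is dropped from the warning lists when re.search matches it anywhere in the tensor name. A self-contained sketch of that behavior (the key names below are made up for illustration):

    import re

    missing_keys = ["bert.embeddings.position_ids", "cls.seq_relationship.bias"]
    keys_to_ignore_on_load_missing = [r"position_ids"]

    # keep only the keys that no ignore pattern matches
    for pat in keys_to_ignore_on_load_missing:
        missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    print(missing_keys)  # ['cls.seq_relationship.bias'] -- position_ids is no longer reported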
@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
-          from the model when loading the model weights (and avoid unnecessary warnings).
-        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
-          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None

     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -742,12 +742,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             model(model.dummy_inputs, training=False)  # Make sure restore ops are run

-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         if len(unexpected_keys) > 0:
...
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
-          when loading the model (and avoid unnecessary warnings).
-        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of tensor names to ignore when saving the
-          model (useful for keys that aren't trained, but which are deterministic)
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
-    keys_to_never_save = None
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of tensor names to ignore when saving the model (useful for keys that aren't
+    # trained, but which are deterministic)
+    _keys_to_ignore_on_save = None

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         state_dict = model_to_save.state_dict()

         # Handle the case where some state_dict keys shouldn't be saved
-        if self.keys_to_never_save is not None:
-            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+        if self._keys_to_ignore_on_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}

         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
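
Note that _keys_to_ignore_on_save filters differently from the load-time lists: it is an exact membership test (`k not in ...`), not a regex search. A minimal standalone sketch of the dict comprehension renamed above (tensor shapes are arbitrary):

    import torch

    state_dict = {
        "model.encoder.embed_positions.weight": torch.zeros(10, 4),
        "model.encoder.layers.0.weight": torch.zeros(4, 4),
    }
    keys_to_ignore_on_save = ["model.encoder.embed_positions.weight"]

    # exact key match, unlike the re.search-based load filters
    state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_ignore_on_save}
    print(sorted(state_dict))  # ['model.encoder.layers.0.weight']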
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             # Some models may have keys that are not in the state by design, removing them before needlessly warning
             # the user.
-            if cls.authorized_missing_keys is not None:
-                for pat in cls.authorized_missing_keys:
+            if cls._keys_to_ignore_on_load_missing is not None:
+                for pat in cls._keys_to_ignore_on_load_missing:
                     missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-            if cls.authorized_unexpected_keys is not None:
-                for pat in cls.authorized_unexpected_keys:
+            if cls._keys_to_ignore_on_load_unexpected is not None:
+                for pat in cls._keys_to_ignore_on_load_unexpected:
                     unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

             if len(unexpected_keys) > 0:
...
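
End to end, the user-visible effect of these lists is quieter loading: matching keys are dropped before from_pretrained logs its missing/unexpected-key warnings. A usage sketch (the checkpoint name is the standard public one; the exact log text varies by version):

    from transformers import BertForMaskedLM

    # bert-base-uncased contains pooler weights that the MLM head never uses; because
    # BertForMaskedLM lists r"pooler" in _keys_to_ignore_on_load_unexpected, loading
    # proceeds without an "unexpected keys" warning for them.
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")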
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):

     config_class = AlbertConfig
     base_model_prefix = "albert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -851,7 +851,7 @@ class AlbertSOPHead(nn.Module):
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1021,7 +1021,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
 )
 class AlbertForTokenClassification(AlbertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1110,7 +1110,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
 )
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -843,7 +843,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
 )
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
 )
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
...
@@ -946,7 +946,7 @@ class BartModel(PretrainedBartModel):
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

     def __init__(self, config: BartConfig):
         super().__init__(config)
...
@@ -1020,10 +1020,10 @@ class TFBartModel(TFPretrainedBartModel):
 )
 class TFBartForConditionalGeneration(TFPretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
     ]
-    authorized_unexpected_keys = [
+    _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
     ]
...
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -969,8 +969,8 @@ class BertForPreTraining(BertPreTrainedModel):
 )
 class BertLMHeadModel(BertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1087,8 +1087,8 @@ class BertLMHeadModel(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1469,7 +1469,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
 )
 class BertForTokenClassification(BertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1560,7 +1560,7 @@ class BertForTokenClassification(BertPreTrainedModel):
 )
 class BertForQuestionAnswering(BertPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
...
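
The recurring r"pooler" entries exist because pretraining checkpoints carry a pooler layer that the MLM, token-classification, and QA heads never instantiate. A quick way to see those keys in a checkpoint (the key names in the comment are what bert-base-uncased is expected to contain; this downloads the model):

    from transformers import BertForPreTraining

    model = BertForPreTraining.from_pretrained("bert-base-uncased")
    print([k for k in model.state_dict() if "pooler" in k])
    # expected: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']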
@@ -938,8 +938,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1023,8 +1023,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1416,8 +1416,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1502,8 +1502,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
...
@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
     config_class = BertGenerationConfig
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
...
@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
     config_class = DebertaConfig
     base_model_prefix = "deberta"
-    authorized_missing_keys = ["position_ids"]
+    _keys_to_ignore_on_load_missing = ["position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
...
@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "ctx_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.ctx_encoder.init_weights()
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "span_predictor"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.span_predictor.encoder.init_weights()
...
@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
-    authorized_missing_keys = [r"position_ids"]
-    authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
...
@@ -1005,11 +1005,11 @@ class FSMTModel(PretrainedFSMTModel):
 )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
...
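
FSMT (and Marian below) put the positional-embedding weights in both the load-missing and on-save lists because their sinusoidal position embeddings are deterministic: they can be recomputed from shape alone, so storing them in the checkpoint is redundant. A minimal sketch of that determinism (a standalone helper, not the library's implementation; dim is assumed even):

    import math
    import torch

    def sinusoidal_embeddings(n_positions: int, dim: int) -> torch.Tensor:
        # a pure function of (position, dimension): identical on every call,
        # so the resulting weight never needs to be saved or loaded
        pe = torch.zeros(n_positions, dim)
        position = torch.arange(0, n_positions, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe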
@@ -780,7 +780,7 @@ class GPT2Model(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1097,7 +1097,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
     config_class = LayoutLMConfig
     base_model_prefix = "layoutlm"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
...
@@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1621,7 +1621,7 @@ class LongformerModel(LongformerPreTrainedModel):
 @add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
 class LongformerForMaskedLM(LongformerPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1718,7 +1718,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
 )
 class LongformerForSequenceClassification(LongformerPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1827,7 +1827,7 @@ class LongformerClassificationHead(nn.Module):
 )
 class LongformerForQuestionAnswering(LongformerPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1961,7 +1961,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
 )
 class LongformerForTokenClassification(LongformerPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1961,7 +1961,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
 )
 class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2048,7 +2048,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
 )
 class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2199,7 +2199,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
 )
 class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2443,7 +2443,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
 )
 class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
...
@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):
     """

     config_class = MarianConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
...
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)
 @add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
 class TFMarianMTModel(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
     ]
...