Unverified Commit bac2d29a authored by Nicolas Patry, committed by GitHub

Attempting to test automatically the `_keys_to_ignore`. (#20042)



* Attempting to test automatically the `_keys_to_ignore`.

* Style.

* First fix pass.

* Moving test on its own.

* Another batch.

* Second round removing BatchNorm

* Fixing layoutlmv{2,3} + support older Python.

* Disable missing-keys warning.

* Removing dodgy additions.

* Big pass.

* mbart.

* More corrections.

* Fixup.

* Updating test_correct_missing_keys (a sketch of this idea follows the commit metadata below).

* Add escape hatch for when the head has no extra params, so it doesn't need the missing-keys check.

* Fixing test.

* Greener.

* Green! (except for a weird Splinter bug).

* Adding a test about `named_parameters` usage.

* Shorten message.

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* After rebase modifications.

* More explicit condition checking.

* Fixing slow tests issues.

* Remove extra pdb.

* Remove print.

* Attempt to make failure consistent + fixing roc_bert.

* Removing the seed (all tests passing with it).
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d606d566
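The idea behind the new check, as described in the commit messages above: for every model class with a head, the keys listed in `_keys_to_ignore_on_load_missing` must never hide parameters that genuinely need to be reported as missing. The sketch below is not the exact test added by this commit; it is a minimal illustration of that idea, using `BertForSequenceClassification` purely as a concrete example of a model with a head on top of a base model.

```python
import tempfile

from transformers import BertConfig, BertForSequenceClassification


def check_missing_keys_not_swallowed(model_class, config):
    """Save only the base model, reload it through the head class, and verify that the
    head's own parameters are still reported as missing, i.e. that they are not hidden
    by `_keys_to_ignore_on_load_missing`."""
    model = model_class(config)
    prefix = model.base_model_prefix
    # Parameters that belong to the head only (not part of the base model).
    head_params = [name for name, _ in model.named_parameters() if not name.startswith(prefix)]
    if not head_params:
        # Escape hatch mentioned in the commit messages: a head with no extra
        # parameters has nothing that could be wrongly ignored.
        return

    with tempfile.TemporaryDirectory() as tmp:
        model.base_model.save_pretrained(tmp)
        _, loading_info = model_class.from_pretrained(tmp, output_loading_info=True)

    assert len(loading_info["missing_keys"]) > 0, (
        f"{model_class.__name__}: head parameters such as {head_params[0]} were silently "
        "ignored; check `_keys_to_ignore_on_load_missing`."
    )


# Tiny config so the example runs quickly.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
check_missing_keys_not_swallowed(BertForSequenceClassification, config)
```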
......@@ -1140,6 +1140,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
PEGASUS_START_DOCSTRING,
)
class PegasusModel(PegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: PegasusConfig):
super().__init__(config)
......@@ -1296,6 +1298,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
r"decoder.version",
r"lm_head.weight",
r"embed_positions.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
def __init__(self, config: PegasusConfig):
......@@ -1496,6 +1500,8 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
class PegasusForCausalLM(PegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
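For encoder-decoder models like Pegasus, `encoder.embed_tokens.weight` and `decoder.embed_tokens.weight` are tied to the shared embedding, so they can legitimately be absent from a checkpoint; listing them keeps the loader from warning about them. At load time the class-level entries are applied as regular expressions against the missing keys, roughly as in the simplified sketch below (illustrative only, not the exact code from `modeling_utils.py`).

```python
import re


def filter_missing_keys(missing_keys, keys_to_ignore_on_load_missing):
    """Drop expected-missing keys so only genuinely unexpected ones trigger a warning."""
    if keys_to_ignore_on_load_missing is None:
        return missing_keys
    kept = []
    for key in missing_keys:
        # Each entry is searched as a regular expression anywhere in the key,
        # which is why plain strings such as "encoder.embed_tokens.weight" also work.
        if any(re.search(pattern, key) for pattern in keys_to_ignore_on_load_missing):
            continue  # expected to be missing (tied or re-initialized weights)
        kept.append(key)
    return kept


missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "classifier.weight"]
print(filter_missing_keys(missing, ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]))
# -> ['classifier.weight']
```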
......@@ -1375,6 +1375,8 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
PEGASUS_X_START_DOCSTRING,
)
class PegasusXModel(PegasusXPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: PegasusXConfig):
super().__init__(config)
......@@ -1522,6 +1524,8 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
r"decoder.version",
r"lm_head.weight",
r"embed_positions.weight",
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
def __init__(self, config: PegasusXConfig):
......
......@@ -1125,6 +1125,8 @@ class PLBartDecoder(PLBartPreTrainedModel):
PLBART_START_DOCSTRING,
)
class PLBartModel(PLBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: PLBartConfig):
super().__init__(config)
......@@ -1247,6 +1249,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
def __init__(self, config: PLBartConfig):
......@@ -1411,6 +1415,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
PLBART_START_DOCSTRING,
)
class PLBartForSequenceClassification(PLBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: PLBartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = PLBartModel(config)
......@@ -1548,6 +1554,8 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
class PLBartForCausalLM(PLBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
......@@ -859,11 +859,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
):
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
-        assert list(hidden_states.size()) == [
-            batch_size,
-            ngram_sequence_length,
-            hidden_size,
-        ], (
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
f" {hidden_states.shape}"
)
......@@ -1774,6 +1770,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetModel(ProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
......@@ -1901,6 +1899,12 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"decoder.word_embeddings.weight",
"encoder.word_embeddings.weight",
"lm_head.weight",
]
def __init__(self, config: ProphetNetConfig):
super().__init__(config)
self.prophetnet = ProphetNetModel(config)
......@@ -2111,6 +2115,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
PROPHETNET_START_DOCSTRING,
)
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config: ProphetNetConfig):
# set config for CLM
config = copy.deepcopy(config)
......
......@@ -1140,6 +1140,8 @@ class RealmBertModel(RealmPreTrainedModel):
REALM_START_DOCSTRING,
)
class RealmEmbedder(RealmPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -1368,6 +1370,8 @@ class RealmScorer(RealmPreTrainedModel):
REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
self.realm = RealmBertModel(self.config)
......
......@@ -2192,6 +2192,8 @@ class ReformerModel(ReformerPreTrainedModel):
@add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
class ReformerModelWithLMHead(ReformerPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
......
......@@ -1051,6 +1051,8 @@ class RoCBertModel(RoCBertPreTrainedModel):
ROC_BERT_START_DOCSTRING,
)
class RoCBertForPreTraining(RoCBertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1235,7 +1237,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
@add_start_docstrings("""RoCBert Model with a `language modeling` head on top.""", ROC_BERT_START_DOCSTRING)
class RoCBertForMaskedLM(RoCBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
def __init__(self, config):
......@@ -1361,7 +1363,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
)
class RoCBertForCausalLM(RoCBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
def __init__(self, config):
......
......@@ -954,6 +954,8 @@ class RoFormerModel(RoFormerPreTrainedModel):
@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
class RoFormerForMaskedLM(RoFormerPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1055,8 +1057,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
"""RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
)
class RoFormerForCausalLM(RoFormerPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1256,6 +1256,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
r"decoder.version",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
r"lm_head.weight",
]
_keys_to_ignore_on_save = [
r"model.encoder.embed_positions.weights",
......
......@@ -745,6 +745,8 @@ class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel):
SPEECH_TO_TEXT_2_START_DOCSTRING,
)
class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
......@@ -648,7 +648,11 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top.""", SQUEEZEBERT_START_DOCSTRING)
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [
r"predictions.decoder.bias",
"cls.predictions.decoder.weight",
"embeddings.position_ids",
]
def __init__(self, config):
super().__init__(config)
......
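The `embeddings.position_ids` entry most likely covers a non-trainable buffer rather than a weight: registered buffers appear in a module's `state_dict()` (and can therefore be reported as missing when a checkpoint predates them) but are not returned by `named_parameters()`. A small PyTorch-only illustration, unrelated to any specific transformers class:

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    """Toy module showing why a buffer like `position_ids` can show up as 'missing'."""

    def __init__(self, vocab_size=10, hidden_size=4, max_positions=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        # Registered buffer: included in state_dict(), excluded from named_parameters().
        self.register_buffer("position_ids", torch.arange(max_positions).unsqueeze(0))


emb = ToyEmbeddings()
print("position_ids in state_dict:      ", "position_ids" in emb.state_dict())                           # True
print("position_ids in named_parameters:", any(n == "position_ids" for n, _ in emb.named_parameters()))  # False
```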
......@@ -1758,9 +1758,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
- authorized_missing_keys = [
-     r"encoder.embed_tokens.weight",
- ]
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
......
......@@ -1004,6 +1004,7 @@ class TapasModel(TapasPreTrainedModel):
@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
class TapasForMaskedLM(TapasPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
config_class = TapasConfig
base_model_prefix = "tapas"
......
......@@ -1006,6 +1006,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = TransfoXLModel(config)
......
......@@ -785,6 +785,8 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
TROCR_START_DOCSTRING,
)
class TrOCRForCausalLM(TrOCRPreTrainedModel):
_keys_to_ignore_on_load_missing = ["output_projection.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
......@@ -890,6 +890,8 @@ class ViltPooler(nn.Module):
VILT_START_DOCSTRING,
)
class ViltForMaskedLM(ViltPreTrainedModel):
_keys_to_ignore_on_load_missing = ["mlm_score.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -871,6 +871,8 @@ class VisualBertModel(VisualBertPreTrainedModel):
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForPreTraining(VisualBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -1459,6 +1461,8 @@ class VisualBertRegionToPhraseAttention(nn.Module):
VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -825,6 +825,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"model.embed_positions.weights",
r"embed_positions.weights",
r"lm_head.weight",
]
_keys_to_ignore_on_save = [
......
......@@ -673,6 +673,8 @@ class XLMPredLayer(nn.Module):
XLM_START_DOCSTRING,
)
class XLMWithLMHeadModel(XLMPreTrainedModel):
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = XLMModel(config)
......
......@@ -876,11 +876,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
):
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
-        assert list(hidden_states.size()) == [
-            batch_size,
-            ngram_sequence_length,
-            hidden_size,
-        ], (
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
f" {hidden_states.shape}"
)
......@@ -1798,6 +1794,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
def __init__(self, config: XLMProphetNetConfig):
super().__init__(config)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
......@@ -1926,6 +1924,12 @@ class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"decoder.word_embeddings.weight",
"encoder.word_embeddings.weight",
"lm_head.weight",
]
def __init__(self, config: XLMProphetNetConfig):
super().__init__(config)
self.prophetnet = XLMProphetNetModel(config)
......@@ -2139,6 +2143,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
)
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config: XLMProphetNetConfig):
# set config for CLM
config = copy.deepcopy(config)
......
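To see the effect of these lists on a real checkpoint, the loading report can be inspected directly via `output_loading_info=True`; keys matched by `_keys_to_ignore_on_load_missing` should no longer appear in it. A short usage sketch (the checkpoint name is just an example):

```python
from transformers import PegasusForConditionalGeneration

# Any Pegasus checkpoint works the same way; "google/pegasus-xsum" is only an example.
model, loading_info = PegasusForConditionalGeneration.from_pretrained(
    "google/pegasus-xsum", output_loading_info=True
)
print("missing keys:   ", loading_info["missing_keys"])
print("unexpected keys:", loading_info["unexpected_keys"])
```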