Unverified commit bac2d29a, authored by Nicolas Patry and committed by GitHub

Attempting to test automatically the `_keys_to_ignore`. (#20042)



* Attempting to test automatically the `_keys_to_ignore`.

* Style.

* First fix pass.

* Moving test on its own.

* Another batch.

* Second round removing BatchNorm

* Fixing layoutlmv{2,3} + support older Python.

* Disable the missing-keys warning.

* Removing dodgy additions.

* Big pass.

* mbart.

* More corrections.

* Fixup.

* Updating test_correct_missing_keys

* Add escape hatch for when the head has no extra params, so it doesn't need the missing keys check.

* Fixing test.

* Greener.

* Green! (except for a weird Splinter bug).

* Adding a test about `named_parameters` usage.

* Shorten message.

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* After rebase modifications.

* More explicit condition checking.

* Fixing slow tests issues.

* Remove extra pdb.

* Remove print.

* Attempt to make failure consistent + fixing roc_bert.

* Removing the seed (all tests passing with it).
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d606d566
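What the series of commits above converges on: each model class declares, via `_keys_to_ignore_on_load_missing`, the weights that are legitimately absent from a checkpoint because they are tied or rebuilt at load time, and a test verifies those declarations automatically instead of relying on hand maintenance. A minimal sketch of that kind of check follows; the helper name and the save/load round-trip strategy are illustrative assumptions, not the exact test added by this PR, while `output_loading_info=True` is the public API for reading back the missing keys.

```python
# A minimal sketch, assuming a save/load round trip; not the exact test in this PR.
import re
import tempfile

from transformers import BertForMaskedLM


def check_missing_keys_are_declared(model_class, checkpoint):
    model = model_class.from_pretrained(checkpoint)
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        _, loading_info = model_class.from_pretrained(tmp_dir, output_loading_info=True)

    patterns = model_class._keys_to_ignore_on_load_missing or []
    for key in loading_info["missing_keys"]:
        # Every key reported as missing must be covered by a declared pattern,
        # otherwise users see a spurious "weights were not initialized" warning.
        assert any(re.search(pattern, key) for pattern in patterns), (
            f"{key} is reported missing but not listed in _keys_to_ignore_on_load_missing"
        )


check_missing_keys_are_declared(BertForMaskedLM, "bert-base-uncased")
```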
......@@ -2421,8 +2421,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
add_prefix_to_model = has_prefix_module and not expects_prefix_module
if remove_prefix_from_model:
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
_prefix = f"{prefix}."
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
......@@ -2641,13 +2642,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
module_keys = module_keys.union(
set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()])
)
retrieved_modules = []
# retrieve all modules that has at least one missing weight name
for name, module in self.named_modules():
if remove_prefix:
name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
_prefix = f"{self.base_model_prefix}."
name = name[len(_prefix) :] if name.startswith(_prefix) else name
elif add_prefix:
name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
......
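The second `modeling_utils.py` hunk hardens the same routine: the `ParameterList` special case now guards against empty names before indexing `key[-1]`, and the base-model prefix is again removed by slicing `f"{self.base_model_prefix}."` rather than splitting on dots. A small, self-contained illustration (not code from the PR) of the `ParameterList` naming scheme the comment describes:

```python
# Torch appends both the ParameterList attribute name and the parameter index
# to the owning module's name, hence the "two parameter keywords" comment.
import torch
from torch import nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.special_embeddings = nn.ParameterList(
            [nn.Parameter(torch.zeros(4)) for _ in range(2)]
        )


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = Inner()


names = [n for n, _ in Outer().named_parameters()]
# -> ['bert.special_embeddings.0', 'bert.special_embeddings.1']

# Dropping the last two components recovers the module that owns the list; the
# `n and n[-1].isdigit()` guard mirrors the diff's protection against empty names.
owners = {".".join(n.split(".")[:-2]) for n in names if n and n[-1].isdigit()}
assert owners == {"bert"}
```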
......@@ -762,6 +762,12 @@ class AlbertModel(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]
def __init__(self, config: AlbertConfig):
super().__init__(config)
......@@ -910,6 +916,11 @@ class AlbertSOPHead(nn.Module):
class AlbertForMaskedLM(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]
def __init__(self, config):
super().__init__(config)
......
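Most of the per-model additions in this commit are instances of one pattern: the head's decoder weight and bias are tied to other parameters (the input embeddings and a standalone bias), so checkpoints may legitimately omit them and `tie_weights()` restores them after loading, which is exactly what `_keys_to_ignore_on_load_missing` is meant to silence. A simplified sketch of the tying, with names that follow the usual BERT/ALBERT layout but are not the exact modules:

```python
# Simplified weight tying: the decoder shares its weight tensor with the
# embedding matrix, so the checkpoint only needs to store the embedding copy.
import torch
from torch import nn

vocab_size, hidden = 100, 16
embeddings = nn.Embedding(vocab_size, hidden)
decoder = nn.Linear(hidden, vocab_size)

decoder.weight = embeddings.weight
assert decoder.weight.data_ptr() == embeddings.weight.data_ptr()
```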
......@@ -1153,6 +1153,8 @@ class BartDecoder(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartModel(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig):
super().__init__(config)
......@@ -1281,7 +1283,12 @@ class BartModel(BartPretrainedModel):
)
class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
def __init__(self, config: BartConfig):
super().__init__(config)
......@@ -1451,6 +1458,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartForSequenceClassification(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
......@@ -1578,6 +1587,8 @@ class BartForSequenceClassification(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1714,6 +1725,8 @@ class BartDecoderWrapper(BartPretrainedModel):
BART_START_DOCSTRING,
)
class BartForCausalLM(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
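The seq2seq models (BART here, and the BigBirdPegasus and Blenderbot variants further down) follow the same logic with their token embeddings: `encoder.embed_tokens` and `decoder.embed_tokens` both point at the single `shared` embedding, so their keys can be absent from a checkpoint without any weight actually being lost. A rough sketch of that sharing, simplified rather than BART's actual module code:

```python
# Simplified sharing scheme; not BART's actual modules.
import torch
from torch import nn

vocab_size, d_model = 128, 16
shared = nn.Embedding(vocab_size, d_model)


class Encoder(nn.Module):
    def __init__(self, embed_tokens):
        super().__init__()
        self.embed_tokens = embed_tokens


class Decoder(nn.Module):
    def __init__(self, embed_tokens):
        super().__init__()
        self.embed_tokens = embed_tokens


encoder, decoder = Encoder(shared), Decoder(shared)
# All three names resolve to the same tensor, so only one copy is stored.
assert encoder.embed_tokens.weight is shared.weight is decoder.embed_tokens.weight
```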
......@@ -1047,6 +1047,8 @@ class BertModel(BertPreTrainedModel):
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1153,7 +1155,7 @@ class BertForPreTraining(BertPreTrainedModel):
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1288,7 +1290,7 @@ class BertLMHeadModel(BertPreTrainedModel):
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -855,6 +855,8 @@ class BertGenerationOnlyLMHead(nn.Module):
BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
......
......@@ -2262,6 +2262,8 @@ class BigBirdModel(BigBirdPreTrainedModel):
class BigBirdForPreTraining(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -2366,6 +2368,8 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -2508,8 +2512,12 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
class BigBirdForCausalLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"predictions.decoder.bias",
"cls.predictions.decoder.weight",
"cls.predictions.decoder.bias",
]
def __init__(self, config):
super().__init__(config)
......
......@@ -2350,6 +2350,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_bart.BartModel with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
......@@ -2480,7 +2482,12 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
......@@ -2651,6 +2658,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BigBirdPegasusModel(config)
......@@ -2779,6 +2788,8 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -2910,6 +2921,8 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
......@@ -1087,6 +1087,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
BLENDERBOT_START_DOCSTRING,
)
class BlenderbotModel(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
......@@ -1231,6 +1233,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
def __init__(self, config: BlenderbotConfig):
......@@ -1420,6 +1424,8 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
......@@ -1081,6 +1081,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
BLENDERBOT_SMALL_START_DOCSTRING,
)
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotSmallConfig):
super().__init__(config)
......@@ -1213,6 +1215,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
r"encoder.version",
r"decoder.version",
r"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
def __init__(self, config: BlenderbotSmallConfig):
......@@ -1387,6 +1391,8 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
config.is_decoder = True
......
......@@ -763,6 +763,8 @@ CONVBERT_INPUTS_DOCSTRING = r"""
CONVBERT_START_DOCSTRING,
)
class ConvBertModel(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
self.embeddings = ConvBertEmbeddings(config)
......@@ -877,6 +879,8 @@ class ConvBertGeneratorPredictions(nn.Module):
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -987,6 +991,8 @@ class ConvBertClassificationHead(nn.Module):
CONVBERT_START_DOCSTRING,
)
class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1083,6 +1089,8 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
CONVBERT_START_DOCSTRING,
)
class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
......@@ -1177,6 +1185,8 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
CONVBERT_START_DOCSTRING,
)
class ConvBertForTokenClassification(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
......@@ -1259,6 +1269,8 @@ class ConvBertForTokenClassification(ConvBertPreTrainedModel):
CONVBERT_START_DOCSTRING,
)
class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
def __init__(self, config):
super().__init__(config)
......
......@@ -509,6 +509,8 @@ class CTRLModel(CTRLPreTrainedModel):
CTRL_START_DOCSTRING,
)
class CTRLLMHeadModel(CTRLPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = CTRLModel(config)
......
......@@ -1038,7 +1038,7 @@ class DebertaModel(DebertaPreTrainedModel):
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
class DebertaForMaskedLM(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1139,7 +1139,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1788,6 +1788,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
DEFORMABLE_DETR_START_DOCSTRING,
)
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
......
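Deformable DETR is the one place above where regex patterns do real work: when the detection heads are cloned per decoder layer, only the clones (index 1 and up) may be missing, while head 0 must still load. A quick, illustrative check of those patterns:

```python
# Illustrative check of the clone patterns: indices >= 1 match, layer 0 does not,
# so layer 0 still has to be present in the checkpoint.
import re

patterns = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
assert re.search(patterns[0], "bbox_embed.3.layers.0.weight")
assert re.search(patterns[1], "class_embed.12.weight")
assert not any(re.search(p, "bbox_embed.0.layers.0.weight") for p in patterns)
```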
......@@ -579,6 +579,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
def __init__(self, config: PretrainedConfig):
super().__init__(config)
......
......@@ -1161,6 +1161,8 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
ELECTRA_START_DOCSTRING,
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1530,6 +1532,8 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
"""ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
)
class ElectraForCausalLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -977,6 +977,8 @@ class ErnieModel(ErniePreTrainedModel):
ERNIE_START_DOCSTRING,
)
class ErnieForPreTraining(ErniePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
def __init__(self, config):
super().__init__(config)
......@@ -1087,7 +1089,7 @@ class ErnieForPreTraining(ErniePreTrainedModel):
)
class ErnieForCausalLM(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
def __init__(self, config):
......@@ -1228,7 +1230,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
@add_start_docstrings("""Ernie Model with a `language modeling` head on top.""", ERNIE_START_DOCSTRING)
class ErnieForMaskedLM(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
def __init__(self, config):
......
......@@ -896,7 +896,7 @@ class EsmModel(EsmPreTrainedModel):
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_missing = [r"position_ids", "lm_head.decoder.weight"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
......
......@@ -657,6 +657,8 @@ class FlaubertModel(FlaubertPreTrainedModel):
)
# Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertWithLMHeadModel(FlaubertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = FlaubertModel(config)
......
......@@ -1729,6 +1729,14 @@ class FlavaGlobalContrastiveHead(nn.Module):
FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
)
class FlavaForPreTraining(FlavaPreTrainedModel):
# Those are linked to xxx.bias
_keys_to_ignore_on_load_missing = [
"mmm_text_head.decoder.bias",
"mmm_image_head.decoder.bias",
"mlm_head.decoder.bias",
"mim_head.decoder.bias",
]
def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
super().__init__(config)
self.flava = FlavaModel(config)
......
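FLAVA's entries are the bias-only variant of the tying pattern: each head's `decoder.bias` is re-pointed at a separately registered `bias` parameter (the "linked to xxx.bias" comment above), so it is rebuilt at load time and safe to ignore when missing. A simplified sketch of that linkage, with made-up module names rather than FLAVA's exact heads:

```python
# The decoder Linear's bias is replaced by a standalone Parameter, so both names
# refer to one tensor and the `*.decoder.bias` key can be recreated on load.
import torch
from torch import nn


class TinyHead(nn.Module):
    def __init__(self, hidden=8, vocab=32):
        super().__init__()
        self.decoder = nn.Linear(hidden, vocab)
        self.bias = nn.Parameter(torch.zeros(vocab))
        self.decoder.bias = self.bias  # link: decoder.bias and bias are one tensor


head = TinyHead()
assert head.decoder.bias is head.bias
```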