"vscode:/vscode.git/clone" did not exist on "42fed15c81bbba2af8e4dd9f03930ce011eafa7e"
Unverified commit bac2d29a, authored by Nicolas Patry, committed by GitHub

Attempting to test automatically the `_keys_to_ignore`. (#20042)



* Attempting to test automatically the `_keys_to_ignore`.

* Style.

* First fix pass.

* Moving test on its own.

* Another batch.

* Second round removing BatchNorm

* Fixing layoutlmv{2,3} + support older Python.

* Disable missing-keys warning.

* Removing dodgy additions.

* Big pass.

* mbart.

* More corrections.

* Fixup.

* Updating test_correct_missing_keys

* Add escape hatch for when the head has no extra params, so it doesn't need the missing keys check.

* Fixing test.

* Greener.

* Green! (except for a weird splinter bug).

* Adding a test about `named_parameters` usage.

* Shorten message.

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* After rebase modifications.

* More explicit condition checking.

* Fixing slow tests issues.

* Remove extra pdb.

* Remove print.

* Attempt to make failure consistent + fixing roc_bert.

* Removing the seed (all tests passing with it).
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d606d566
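The diff below adds or extends `_keys_to_ignore_on_load_missing` on many head classes so that tied or re-created weights (shared embeddings, `lm_head.weight`, prediction-head decoder biases, and so on) no longer show up as spuriously "missing" when a checkpoint is loaded. To make the intent concrete, here is a hypothetical, simplified sketch of the kind of consistency check the commit message describes, cross-referencing the ignore patterns against a model's actual weight names. It is not the test added by this PR (the helper name `find_stale_ignore_patterns` and the tiny `PegasusConfig` are only there to keep the example self-contained and cheap to build).

import re

from transformers import PegasusConfig, PegasusForConditionalGeneration


def find_stale_ignore_patterns(model) -> list:
    """Return `_keys_to_ignore_on_load_missing` patterns that match no weight of `model`.

    Hypothetical helper, loosely inspired by the automatic check described in
    the commit message; the test that actually ships in the repository may
    differ in its details.
    """
    patterns = getattr(model, "_keys_to_ignore_on_load_missing", None) or []
    # state_dict() lists tied weights (e.g. `lm_head.weight`) under their own
    # keys, whereas named_parameters() deduplicates shared parameters.
    weight_names = list(model.state_dict().keys())
    return [
        pattern
        for pattern in patterns
        if not any(re.search(pattern, name) for name in weight_names)
    ]


# Tiny configuration so the example builds quickly; no checkpoint download needed.
config = PegasusConfig(
    vocab_size=128,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = PegasusForConditionalGeneration(config)
# Legacy-only patterns such as r"encoder.version" match nothing and get reported here.
print(find_stale_ignore_patterns(model))

Note that the sketch reads names from `state_dict()` rather than `named_parameters()`: tied weights such as `lm_head.weight` appear under their own key in the state dict but are deduplicated by `named_parameters()`, which may be the pitfall the "`named_parameters` usage" bullet alludes to.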
@@ -1140,6 +1140,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
     PEGASUS_START_DOCSTRING,
 )
 class PegasusModel(PegasusPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: PegasusConfig):
         super().__init__(config)
@@ -1296,6 +1298,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         r"decoder.version",
         r"lm_head.weight",
         r"embed_positions.weight",
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
     def __init__(self, config: PegasusConfig):
@@ -1496,6 +1500,8 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
 class PegasusForCausalLM(PegasusPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
     def __init__(self, config):
         config = copy.deepcopy(config)
         config.is_decoder = True
...
@@ -1375,6 +1375,8 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
     PEGASUS_X_START_DOCSTRING,
 )
 class PegasusXModel(PegasusXPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+
     def __init__(self, config: PegasusXConfig):
         super().__init__(config)
@@ -1522,6 +1524,8 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         r"decoder.version",
         r"lm_head.weight",
         r"embed_positions.weight",
+        "decoder.embed_tokens.weight",
+        "encoder.embed_tokens.weight",
     ]
     def __init__(self, config: PegasusXConfig):
...
@@ -1125,6 +1125,8 @@ class PLBartDecoder(PLBartPreTrainedModel):
     PLBART_START_DOCSTRING,
 )
 class PLBartModel(PLBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+
     def __init__(self, config: PLBartConfig):
         super().__init__(config)
@@ -1247,6 +1249,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         r"encoder.version",
         r"decoder.version",
         r"lm_head.weight",
+        "decoder.embed_tokens.weight",
+        "encoder.embed_tokens.weight",
     ]
     def __init__(self, config: PLBartConfig):
@@ -1411,6 +1415,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     PLBART_START_DOCSTRING,
 )
 class PLBartForSequenceClassification(PLBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: PLBartConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = PLBartModel(config)
@@ -1548,6 +1554,8 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
 class PLBartForCausalLM(PLBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
     def __init__(self, config):
         config = copy.deepcopy(config)
         config.is_decoder = True
...
@@ -859,11 +859,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
-        assert list(hidden_states.size()) == [
-            batch_size,
-            ngram_sequence_length,
-            hidden_size,
-        ], (
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
         )
@@ -1774,6 +1770,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetModel(ProphetNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
+
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -1901,6 +1899,12 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        "decoder.word_embeddings.weight",
+        "encoder.word_embeddings.weight",
+        "lm_head.weight",
+    ]
+
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
         self.prophetnet = ProphetNetModel(config)
@@ -2111,6 +2115,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
     def __init__(self, config: ProphetNetConfig):
         # set config for CLM
         config = copy.deepcopy(config)
...
@@ -1140,6 +1140,8 @@ class RealmBertModel(RealmPreTrainedModel):
     REALM_START_DOCSTRING,
 )
 class RealmEmbedder(RealmPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
@@ -1368,6 +1370,8 @@ class RealmScorer(RealmPreTrainedModel):
     REALM_START_DOCSTRING,
 )
 class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+
     def __init__(self, config):
         super().__init__(config)
         self.realm = RealmBertModel(self.config)
...
@@ -2192,6 +2192,8 @@ class ReformerModel(ReformerPreTrainedModel):
 @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
 class ReformerModelWithLMHead(ReformerPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
         assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
...
@@ -1051,6 +1051,8 @@ class RoCBertModel(RoCBertPreTrainedModel):
     ROC_BERT_START_DOCSTRING,
 )
 class RoCBertForPreTraining(RoCBertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
     def __init__(self, config):
         super().__init__(config)
@@ -1235,7 +1237,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
 @add_start_docstrings("""RoCBert Model with a `language modeling` head on top.""", ROC_BERT_START_DOCSTRING)
 class RoCBertForMaskedLM(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
     def __init__(self, config):
@@ -1361,7 +1363,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
 )
 class RoCBertForCausalLM(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
     def __init__(self, config):
...
@@ -954,6 +954,8 @@ class RoFormerModel(RoFormerPreTrainedModel):
 @add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
 class RoFormerForMaskedLM(RoFormerPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
     def __init__(self, config):
         super().__init__(config)
@@ -1055,8 +1057,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
     """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
 )
 class RoFormerForCausalLM(RoFormerPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
     def __init__(self, config):
         super().__init__(config)
...
@@ -1256,6 +1256,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         r"decoder.version",
         r"model.encoder.embed_positions.weights",
         r"model.decoder.embed_positions.weights",
+        r"lm_head.weight",
     ]
     _keys_to_ignore_on_save = [
         r"model.encoder.embed_positions.weights",
...
@@ -745,6 +745,8 @@ class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel):
     SPEECH_TO_TEXT_2_START_DOCSTRING,
 )
 class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
     def __init__(self, config):
         config = copy.deepcopy(config)
         config.is_decoder = True
...
@@ -648,7 +648,11 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
 @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top.""", SQUEEZEBERT_START_DOCSTRING)
 class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [
+        r"predictions.decoder.bias",
+        "cls.predictions.decoder.weight",
+        "embeddings.position_ids",
+    ]
     def __init__(self, config):
         super().__init__(config)
...
@@ -1758,9 +1758,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     T5_START_DOCSTRING,
 )
 class T5EncoderModel(T5PreTrainedModel):
-    authorized_missing_keys = [
-        r"encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
     def __init__(self, config: T5Config):
         super().__init__(config)
...
@@ -1004,6 +1004,7 @@ class TapasModel(TapasPreTrainedModel):
 @add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
 class TapasForMaskedLM(TapasPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
     config_class = TapasConfig
     base_model_prefix = "tapas"
...
@@ -1006,6 +1006,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
     TRANSFO_XL_START_DOCSTRING,
 )
 class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.transformer = TransfoXLModel(config)
...
@@ -785,6 +785,8 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
     TROCR_START_DOCSTRING,
 )
 class TrOCRForCausalLM(TrOCRPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["output_projection.weight"]
+
     def __init__(self, config):
         config = copy.deepcopy(config)
         config.is_decoder = True
...
@@ -890,6 +890,8 @@ class ViltPooler(nn.Module):
     VILT_START_DOCSTRING,
 )
 class ViltForMaskedLM(ViltPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["mlm_score.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
...
@@ -871,6 +871,8 @@ class VisualBertModel(VisualBertPreTrainedModel):
     VISUAL_BERT_START_DOCSTRING,
 )
 class VisualBertForPreTraining(VisualBertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
@@ -1459,6 +1461,8 @@ class VisualBertRegionToPhraseAttention(nn.Module):
     VISUAL_BERT_START_DOCSTRING,
 )
 class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
...
@@ -825,6 +825,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [
         r"model.embed_positions.weights",
+        r"embed_positions.weights",
         r"lm_head.weight",
     ]
     _keys_to_ignore_on_save = [
...
@@ -673,6 +673,8 @@ class XLMPredLayer(nn.Module):
     XLM_START_DOCSTRING,
 )
 class XLMWithLMHeadModel(XLMPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.transformer = XLMModel(config)
...
@@ -876,11 +876,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
     ):
         batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
-        assert list(hidden_states.size()) == [
-            batch_size,
-            ngram_sequence_length,
-            hidden_size,
-        ], (
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
             f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
             f" {hidden_states.shape}"
         )
@@ -1798,6 +1794,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
 )
 # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
 class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
+
     def __init__(self, config: XLMProphetNetConfig):
         super().__init__(config)
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
@@ -1926,6 +1924,12 @@ class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
 )
 # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
 class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        "decoder.word_embeddings.weight",
+        "encoder.word_embeddings.weight",
+        "lm_head.weight",
+    ]
+
     def __init__(self, config: XLMProphetNetConfig):
         super().__init__(config)
         self.prophetnet = XLMProphetNetModel(config)
@@ -2139,6 +2143,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
 )
 # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
 class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
     def __init__(self, config: XLMProphetNetConfig):
         # set config for CLM
         config = copy.deepcopy(config)
...
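For context on how these class attributes are consumed: when `from_pretrained` finishes loading a state dict, it filters the lists of missing and unexpected keys through these patterns with `re.search` before emitting its warnings. The helper below (`filter_ignored_keys` is a hypothetical name) is a simplified, self-contained illustration of that filtering, not the real implementation in `modeling_utils.py`.

import re
from typing import Iterable, List, Optional


def filter_ignored_keys(keys: Iterable[str], patterns: Optional[List[str]]) -> List[str]:
    """Drop every key that matches one of the ignore patterns.

    Simplified illustration of how `_keys_to_ignore_on_load_missing` (and its
    `_keys_to_ignore_on_load_unexpected` counterpart) is applied before the
    missing/unexpected-key warnings are emitted; the real logic in
    `modeling_utils.py` has more moving parts.
    """
    remaining = list(keys)
    for pattern in patterns or []:
        remaining = [key for key in remaining if re.search(pattern, key) is None]
    return remaining


# With the keys added in this diff, tied-embedding and lm_head names no longer
# trigger a spurious "newly initialized" warning:
missing = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
patterns = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", r"lm_head.weight"]
print(filter_ignored_keys(missing, patterns))  # -> []

Because the entries are treated as regular expressions, plain strings such as "encoder.embed_tokens.weight" also work: the unescaped dots simply match the literal dots (or any other character) in the weight names.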