"model_cards/vscode:/vscode.git/clone" did not exist on "79f0118c7284fdfe5ffca5090087b42bb230b6ca"
Unverified Commit bac2d29a authored by Nicolas Patry, committed by GitHub

Attempting to test automatically the `_keys_to_ignore`. (#20042)



* Attempting to test automatically the `_keys_to_ignore`.

* Style.

* First fix pass.

* Moving test on its own.

* Another batch.

* Second round removing BatchNorm

* Fixing layoutlmv{2,3} + support older Python.

* Disable missing-keys warning.

* Removing dodgy additions.

* Big pass.

* mbart.

* More corrections.

* Fixup.

* Updating test_correct_missing_keys

* Add escape hatch for when the head has no extra params, so it doesn't need the missing keys check.

* Fixing test.

* Greener.

* Green! (except for weird Splinter bug).

* Adding a test about `named_parameters` usage.

* Shorten message.

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* After rebase modifications.

* More explicit condition checking.

* Fixing slow tests issues.

* Remove extra pdb.

* Remove print.

* Attempt to make failure consistent + fixing roc_bert.

* Removing the seed (all tests passing with it).
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d606d566
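The diff below adds or extends `_keys_to_ignore_on_load_missing` on many model classes. As a rough illustration of the kind of automated consistency check the commit message describes (a minimal sketch with a hypothetical helper name, not the test that was actually merged), one can verify that every ignore pattern still matches at least one key of the model's state dict, so the lists cannot silently go stale:

# Hypothetical sketch, not the exact test added by this commit.
import re

from transformers import FNetConfig, FNetForMaskedLM


def stale_ignore_patterns(model):
    """Return the `_keys_to_ignore_on_load_missing` entries that match nothing on the model."""
    patterns = getattr(model, "_keys_to_ignore_on_load_missing", None) or []
    # state_dict() keys are what `from_pretrained` compares against the checkpoint,
    # and they include tied weights such as "cls.predictions.decoder.weight".
    keys = list(model.state_dict().keys())
    # The entries mix plain prefixes and regexes, so treat each one as a regex search.
    return [p for p in patterns if not any(re.search(p, k) for k in keys)]


model = FNetForMaskedLM(FNetConfig(hidden_size=32, num_hidden_layers=2, intermediate_size=64))
assert stale_ignore_patterns(model) == []

Note that `named_parameters()` reports a tied weight only once, under its canonical name, which is why the commit message also mentions a dedicated test about `named_parameters` usage.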
...@@ -624,6 +624,8 @@ class FNetModel(FNetPreTrainedModel):
    FNET_START_DOCSTRING,
)
class FNetForPreTraining(FNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
    def __init__(self, config):
        super().__init__(config)
...@@ -716,6 +718,8 @@ class FNetForPreTraining(FNetPreTrainedModel):
@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
class FNetForMaskedLM(FNetPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
    def __init__(self, config):
        super().__init__(config)
......
...@@ -992,6 +992,8 @@ def _get_shape(t):
    FSMT_START_DOCSTRING,
)
class FSMTModel(PretrainedFSMTModel):
+    _keys_to_ignore_on_load_missing = ["decoder.output_projection.weight"]
+
    def __init__(self, config: FSMTConfig):
        super().__init__(config)
...@@ -1120,6 +1122,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
    _keys_to_ignore_on_load_missing = [
        "model.encoder.embed_positions.weight",
        "model.decoder.embed_positions.weight",
+        "decoder.output_projection.weight",
    ]
    _keys_to_ignore_on_save = [
        "model.encoder.embed_positions.weight",
......
...@@ -1193,6 +1193,8 @@ class FunnelForPreTraining(FunnelPreTrainedModel):
@add_start_docstrings("""Funnel Transformer Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
class FunnelForMaskedLM(FunnelPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)
......
...@@ -592,7 +592,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
)
class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "embed_out.weight"]
    def __init__(self, config):
        super().__init__(config)
......
...@@ -856,7 +856,7 @@ class IBertModel(IBertPreTrainedModel):
@add_start_docstrings("""I-BERT Model with a `language modeling` head on top.""", IBERT_START_DOCSTRING)
class IBertForMaskedLM(IBertPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias", "lm_head.decoder.weight"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    def __init__(self, config):
......
...@@ -849,6 +849,12 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        "cls.predictions.decoder.bias",
+        "cls.predictions.decoder.weight",
+        "embeddings.position_ids",
+    ]
+
    def __init__(self, config):
        super().__init__(config)
......
...@@ -2212,6 +2212,8 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -2212,6 +2212,8 @@ class LEDDecoder(LEDPreTrainedModel):
LED_START_DOCSTRING, LED_START_DOCSTRING,
) )
class LEDModel(LEDPreTrainedModel): class LEDModel(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: LEDConfig): def __init__(self, config: LEDConfig):
super().__init__(config) super().__init__(config)
...@@ -2341,6 +2343,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): ...@@ -2341,6 +2343,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
r"encoder.version", r"encoder.version",
r"decoder.version", r"decoder.version",
r"lm_head.weight", r"lm_head.weight",
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
] ]
def __init__(self, config: LEDConfig): def __init__(self, config: LEDConfig):
...@@ -2528,6 +2532,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): ...@@ -2528,6 +2532,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
LED_START_DOCSTRING, LED_START_DOCSTRING,
) )
class LEDForSequenceClassification(LEDPreTrainedModel): class LEDForSequenceClassification(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: LEDConfig, **kwargs): def __init__(self, config: LEDConfig, **kwargs):
warnings.warn( warnings.warn(
"The `transformers.LEDForSequenceClassification` class is deprecated and will be removed in version 5 of" "The `transformers.LEDForSequenceClassification` class is deprecated and will be removed in version 5 of"
...@@ -2662,6 +2668,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel): ...@@ -2662,6 +2668,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
LED_START_DOCSTRING, LED_START_DOCSTRING,
) )
class LEDForQuestionAnswering(LEDPreTrainedModel): class LEDForQuestionAnswering(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
......
...@@ -1775,7 +1775,7 @@ class LongformerModel(LongformerPreTrainedModel):
@add_start_docstrings("""Longformer Model with a `language modeling` head on top.""", LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(LongformerPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.decoder"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    def __init__(self, config):
......
...@@ -2137,9 +2137,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
    LONGT5_START_DOCSTRING,
)
class LongT5EncoderModel(LongT5PreTrainedModel):
-    authorized_missing_keys = [
-        r"encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
    def __init__(self, config: LongT5Config):
        super().__init__(config)
......
...@@ -1023,6 +1023,8 @@ class LxmertModel(LxmertPreTrainedModel):
    LXMERT_START_DOCSTRING,
)
class LxmertForPreTraining(LxmertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight"]
+
    def __init__(self, config):
        super().__init__(config)
        # Configuration
......
...@@ -1128,6 +1128,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
    M2M_100_START_DOCSTRING,
)
class M2M100Model(M2M100PreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
    def __init__(self, config: M2M100Config):
        super().__init__(config)
...@@ -1244,12 +1246,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
        r"encoder.version",
        r"decoder.version",
        r"lm_head.weight",
-        r"model.encoder.embed_positions.weights",
-        r"model.decoder.embed_positions.weights",
-    ]
-    _keys_to_ignore_on_save = [
-        r"model.encoder.embed_positions.weights",
-        r"model.decoder.embed_positions.weights",
+        r"encoder.embed_tokens.weight",
+        r"decoder.embed_tokens.weight",
    ]
    def __init__(self, config: M2M100Config):
......
...@@ -1087,6 +1087,8 @@ class MarianDecoder(MarianPreTrainedModel):
    "The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
)
class MarianModel(MarianPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
    def __init__(self, config: MarianConfig):
        super().__init__(config)
...@@ -1278,6 +1280,8 @@ class MarianMTModel(MarianPreTrainedModel):
        r"decoder.version",
        r"lm_head.weight",
        r"embed_positions",
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
    ]
    _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
...@@ -1540,6 +1544,8 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
class MarianForCausalLM(MarianPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
......
...@@ -1150,6 +1150,8 @@ class MBartDecoder(MBartPreTrainedModel):
    MBART_START_DOCSTRING,
)
class MBartModel(MBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
    def __init__(self, config: MBartConfig):
        super().__init__(config)
...@@ -1273,6 +1275,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
        r"encoder.version",
        r"decoder.version",
        r"lm_head.weight",
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
    ]
    def __init__(self, config: MBartConfig):
...@@ -1440,6 +1444,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
    MBART_START_DOCSTRING,
)
class MBartForSequenceClassification(MBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
    def __init__(self, config: MBartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = MBartModel(config)
...@@ -1568,6 +1574,8 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
    MBART_START_DOCSTRING,
)
class MBartForQuestionAnswering(MBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
    def __init__(self, config):
        super().__init__(config)
...@@ -1701,6 +1709,8 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
class MBartForCausalLM(MBartPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
......
...@@ -1009,6 +1009,8 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
    MEGATRON_BERT_START_DOCSTRING,
)
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+
    def __init__(self, config, add_binary_head=True):
        super().__init__(config)
...@@ -1115,7 +1117,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder"]
    def __init__(self, config):
        super().__init__(config)
...@@ -1261,7 +1263,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"seq_relationship"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder"]
    def __init__(self, config):
        super().__init__(config)
......
...@@ -925,6 +925,12 @@ class MobileBertModel(MobileBertPreTrainedModel):
    MOBILEBERT_START_DOCSTRING,
)
class MobileBertForPreTraining(MobileBertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        "cls.predictions.decoder.weight",
+        "cls.predictions.decoder.bias",
+        "embeddings.position_ids",
+    ]
+
    def __init__(self, config):
        super().__init__(config)
        self.mobilebert = MobileBertModel(config)
...@@ -1033,6 +1039,11 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [
+        "cls.predictions.decoder.weight",
+        "cls.predictions.decoder.bias",
+        "embeddings.position_ids",
+    ]
    def __init__(self, config):
        super().__init__(config)
......
...@@ -574,7 +574,7 @@ class MPNetModel(MPNetPreTrainedModel): ...@@ -574,7 +574,7 @@ class MPNetModel(MPNetPreTrainedModel):
class MPNetForMaskedLM(MPNetPreTrainedModel): class MPNetForMaskedLM(MPNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder"]
_keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config): def __init__(self, config):
......
...@@ -1292,6 +1292,7 @@ class MvpDecoder(MvpPreTrainedModel):
)
class MvpModel(MvpPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
    def __init__(self, config: MvpConfig):
        super().__init__(config)
...@@ -1429,6 +1430,8 @@ class MvpModel(MvpPreTrainedModel):
    "The MVP Model with a language modeling head. Can be used for various text generation tasks.", MVP_START_DOCSTRING
)
class MvpForConditionalGeneration(MvpPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
    def __init__(self, config: MvpConfig):
        super().__init__(config)
        self.model = MvpModel(config)
...@@ -1600,6 +1603,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
)
class MvpForSequenceClassification(MvpPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
    def __init__(self, config: MvpConfig, **kwargs):
        super().__init__(config, **kwargs)
...@@ -1727,6 +1731,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
)
class MvpForQuestionAnswering(MvpPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
    def __init__(self, config):
        super().__init__(config)
...@@ -1856,6 +1861,8 @@ class MvpDecoderWrapper(MvpPreTrainedModel):
class MvpForCausalLM(MvpPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
......
...@@ -1038,6 +1038,8 @@ class NezhaModel(NezhaPreTrainedModel):
    NEZHA_START_DOCSTRING,
)
class NezhaForPreTraining(NezhaPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+
    def __init__(self, config):
        super().__init__(config)
...@@ -1140,7 +1142,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel):
class NezhaForMaskedLM(NezhaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", r"positions_encoding"]
+    _keys_to_ignore_on_load_missing = [r"cls.predictions.decoder", r"positions_encoding"]
    def __init__(self, config):
        super().__init__(config)
......
...@@ -660,6 +660,8 @@ class NystromformerModel(NystromformerPreTrainedModel):
@add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
class NystromformerForMaskedLM(NystromformerPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+
    def __init__(self, config):
        super().__init__(config)
......
...@@ -531,6 +531,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
    def __init__(self, config):
        super().__init__(config)
        self.transformer = OpenAIGPTModel(config)
...@@ -621,6 +623,8 @@ input sequence).
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
    def __init__(self, config):
        super().__init__(config)
......
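For context on what these attributes change at load time: keys matching `_keys_to_ignore_on_load_missing` are filtered out of the missing-keys report that `from_pretrained` produces, and the entries touched above mostly cover weights that are tied to the embeddings or rebuilt at load time. Below is a hedged usage sketch (the tiny MarianConfig values are illustrative, not from this PR):

# Hypothetical usage sketch; config values are made up for a small model.
import tempfile

from transformers import MarianConfig, MarianModel, MarianMTModel

config = MarianConfig(
    vocab_size=99,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    pad_token_id=98,
)

with tempfile.TemporaryDirectory() as tmp:
    MarianModel(config).save_pretrained(tmp)  # a checkpoint without the LM head
    model, info = MarianMTModel.from_pretrained(tmp, output_loading_info=True)
    # Keys matching the patterns above (e.g. "lm_head.weight") are dropped from this list.
    print(info["missing_keys"])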