Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
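
The diff below applies one repetitive change across the listed model files: each head class whose output weights are tied to its input embeddings gains an explicit `_tied_weights_keys` class attribute naming those parameters, alongside the existing `_keys_to_ignore_on_*` lists. A minimal sketch of the pattern with toy classes (not the real transformers code), assuming the base class uses the attribute to know which checkpoint keys are recreated by weight tying:

# Minimal sketch of the pattern (toy classes, not the real transformers code):
# a head model declares which of its parameter keys are tied to the input
# embeddings, and ties them when the model is built.

from torch import nn


class ToyPreTrainedModel(nn.Module):
    # Subclasses override this with the state-dict keys of tied parameters.
    _tied_weights_keys = None


class ToyForCausalLM(ToyPreTrainedModel):
    # The LM head shares its weight with the token embedding, so its key is
    # declared here, mirroring e.g. LlamaForCausalLM in the diff below.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size: int = 10, hidden_size: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.embed_tokens.weight


model = ToyForCausalLM()
# Both names refer to the same storage, so "lm_head.weight" can be absent
# from a checkpoint: tying re-creates it from the embedding weight.
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()

Because the two keys share storage, declaring them in `_tied_weights_keys` makes the relationship explicit instead of hiding it in the load-time ignore lists.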
@@ -1112,6 +1112,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
)
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__(config)
...
@@ -856,6 +856,7 @@ class IBertModel(IBertPreTrainedModel):
class IBertForMaskedLM(IBertPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias", "lm_head.decoder.weight"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -894,6 +894,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
)
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ImageGPTConfig):
        super().__init__(config)
...
@@ -862,6 +862,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
        "cls.predictions.decoder.weight",
        "embeddings.position_ids",
    ]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -2210,6 +2210,7 @@ class LEDDecoder(LEDPreTrainedModel):
)
class LEDModel(LEDPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

    def __init__(self, config: LEDConfig):
        super().__init__(config)
@@ -2342,6 +2343,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
        "decoder.embed_tokens.weight",
        "encoder.embed_tokens.weight",
    ]
+    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: LEDConfig):
        super().__init__(config)
@@ -2529,6 +2531,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
)
class LEDForSequenceClassification(LEDPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

    def __init__(self, config: LEDConfig, **kwargs):
        warnings.warn(
@@ -2665,6 +2668,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
)
class LEDForQuestionAnswering(LEDPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -611,6 +611,8 @@ class LlamaModel(LlamaPreTrainedModel):
class LlamaForCausalLM(LlamaPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
...
@@ -1772,6 +1772,7 @@ class LongformerModel(LongformerPreTrainedModel):
class LongformerForMaskedLM(LongformerPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["lm_head.decoder"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -1770,6 +1770,7 @@ class LongT5Model(LongT5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: LongT5Config):
        super().__init__(config)
@@ -1924,6 +1925,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: LongT5Config):
        super().__init__(config)
@@ -2159,6 +2161,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
)
class LongT5EncoderModel(LongT5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight"]

    def __init__(self, config: LongT5Config):
        super().__init__(config)
...
@@ -1289,6 +1289,7 @@ class LukeForMaskedLM(LukePreTrainedModel):
        r"lm_head.decoder.bias",
        r"entity_predictions.decoder.weight",
    ]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -1019,6 +1019,7 @@ class LxmertModel(LxmertPreTrainedModel):
)
class LxmertForPreTraining(LxmertPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -1146,6 +1146,7 @@ class M2M100Model(M2M100PreTrainedModel):
        "decoder.embed_positions.weights",
        "decoder.embed_positions.bias",
    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: M2M100Config):
        super().__init__(config)
@@ -1269,6 +1270,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
        r"decoder.embed_positions.weights",
        r"decoder.embed_positions.bias",
    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: M2M100Config):
        super().__init__(config)
...
@@ -1099,6 +1099,7 @@ class MarianDecoder(MarianPreTrainedModel):
)
class MarianModel(MarianPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MarianConfig):
        super().__init__(config)
@@ -1294,8 +1295,8 @@ class MarianMTModel(MarianPreTrainedModel):
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]
    _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: MarianConfig):
        super().__init__(config)
@@ -1556,6 +1557,7 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
class MarianForCausalLM(MarianPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
...
@@ -1152,6 +1152,7 @@ class MBartDecoder(MBartPreTrainedModel):
)
class MBartModel(MBartPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MBartConfig):
        super().__init__(config)
@@ -1279,6 +1280,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: MBartConfig):
        super().__init__(config)
@@ -1446,6 +1448,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
)
class MBartForSequenceClassification(MBartPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]

    def __init__(self, config: MBartConfig, **kwargs):
        super().__init__(config, **kwargs)
@@ -1575,6 +1578,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
)
class MBartForQuestionAnswering(MBartPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)
@@ -1708,6 +1712,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
class MBartForCausalLM(MBartPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
...
@@ -1659,6 +1659,7 @@ class MegaForCausalLM(MegaPreTrainedModel):
    _keys_to_ignore_on_save = [r"lm_head.weight", r"lm_head.bias"]
    _keys_to_ignore_on_load_missing = [r"lm_head.weight", r"lm_head.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MegaConfig):
        super().__init__(config)
@@ -1823,6 +1824,7 @@ class MegaForMaskedLM(MegaPreTrainedModel):
    _keys_to_ignore_on_save = [r"mlm_head.weight", r"mlm_head.bias"]
    _keys_to_ignore_on_load_missing = [r"mlm_head.weight", r"mlm_head.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["mlm_head.weight"]

    def __init__(self, config: MegaConfig):
        super().__init__(config)
...
@@ -1015,6 +1015,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
)
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config, add_binary_head=True):
        super().__init__(config)
@@ -1122,6 +1123,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)
@@ -1267,6 +1269,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"seq_relationship"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -928,6 +928,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
        "cls.predictions.decoder.bias",
        "embeddings.position_ids",
    ]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
@@ -1041,6 +1042,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
        "cls.predictions.decoder.bias",
        "embeddings.position_ids",
    ]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -572,6 +572,7 @@ class MPNetModel(MPNetPreTrainedModel):
class MPNetForMaskedLM(MPNetPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder"]

    def __init__(self, config):
        super().__init__(config)
...
@@ -1324,6 +1324,7 @@ class MT5Model(MT5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
    def __init__(self, config: MT5Config):
@@ -1556,6 +1557,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
    def __init__(self, config: MT5Config):
@@ -1898,6 +1900,7 @@ class MT5EncoderModel(MT5PreTrainedModel):
        r"encoder.embed_tokens.weight",
    ]
    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight"]

    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5
    def __init__(self, config: MT5Config):
...
@@ -1297,6 +1297,7 @@ class MvpDecoder(MvpPreTrainedModel):
class MvpModel(MvpPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MvpConfig):
        super().__init__(config)
@@ -1433,6 +1434,7 @@ class MvpModel(MvpPreTrainedModel):
)
class MvpForConditionalGeneration(MvpPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: MvpConfig):
        super().__init__(config)
@@ -1606,6 +1608,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
class MvpForSequenceClassification(MvpPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MvpConfig, **kwargs):
        super().__init__(config, **kwargs)
@@ -1734,6 +1737,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
class MvpForQuestionAnswering(MvpPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)
@@ -1865,6 +1869,7 @@ class MvpDecoderWrapper(MvpPreTrainedModel):
class MvpForCausalLM(MvpPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
...
@@ -1038,6 +1038,7 @@ class NezhaModel(NezhaPreTrainedModel):
)
class NezhaForPreTraining(NezhaPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)
@@ -1141,6 +1142,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel):
class NezhaForMaskedLM(NezhaPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"cls.predictions.decoder", r"positions_encoding"]
+    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)
...
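
One way to read the last commit message bullet ("Use consistent function for detecting tied weights") is that tied parameters should be discovered by one shared helper rather than per-model conventions. A hedged sketch of such a helper follows; the name `find_tied_weight_keys` and its signature are illustrative, not necessarily the function the commit adds. It groups state-dict keys by the storage their tensors point to.

# Illustrative helper (an assumption, not the exact function this commit adds):
# report groups of state-dict keys whose tensors share the same underlying
# storage, i.e. tied weights.

from collections import defaultdict
from typing import List

from torch import nn


def find_tied_weight_keys(model: nn.Module) -> List[List[str]]:
    groups = defaultdict(list)
    for name, tensor in model.state_dict(keep_vars=True).items():
        # Tied parameters appear under several keys but point at one storage.
        groups[tensor.data_ptr()].append(name)
    return [names for names in groups.values() if len(names) > 1]

Applied to the toy model sketched above, this reports "embed_tokens.weight" and "lm_head.weight" as one tied group, which is the same information the new `_tied_weights_keys` attributes now spell out per model class.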