"src/lib/vscode:/vscode.git/clone" did not exist on "dbc3aa7234550cf6a50ba1ce2a5330eb2f531f56"
Unverified commit 695928e1, authored by Sylvain Gugger and committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
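The hunks below all follow the same pattern: each class whose parameters are tied to another module's weights now declares them explicitly in a `_tied_weights_keys` class attribute, instead of relying only on `_keys_to_ignore_on_load_missing`. A minimal, self-contained sketch of that convention (the `ToyForCausalLM` class and its sizes are hypothetical, not part of this diff):

# Sketch of the convention applied throughout the diff: the tied parameter is
# listed in `_tied_weights_keys` on the class, and the tie itself is created
# in `__init__` by sharing the embedding weight with the LM head.
import torch.nn as nn


class ToyForCausalLM(nn.Module):  # stand-in for a `*PreTrainedModel` subclass
    # `lm_head.weight` is shared with the input embeddings, so a checkpoint may
    # legitimately omit it; declaring it here marks it as tied, not missing.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie the output projection to the input embeddings.
        self.lm_head.weight = self.embed_tokens.weight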
......@@ -1112,6 +1112,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
)
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GPTSanJapaneseConfig):
super().__init__(config)
......
......@@ -856,6 +856,7 @@ class IBertModel(IBertPreTrainedModel):
class IBertForMaskedLM(IBertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias", "lm_head.decoder.weight"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -894,6 +894,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
)
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
......
......@@ -862,6 +862,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
"cls.predictions.decoder.weight",
"embeddings.position_ids",
]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -2210,6 +2210,7 @@ class LEDDecoder(LEDPreTrainedModel):
)
class LEDModel(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: LEDConfig):
super().__init__(config)
......@@ -2342,6 +2343,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: LEDConfig):
super().__init__(config)
......@@ -2529,6 +2531,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
)
class LEDForSequenceClassification(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: LEDConfig, **kwargs):
warnings.warn(
......@@ -2665,6 +2668,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
)
class LEDForQuestionAnswering(LEDPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -611,6 +611,8 @@ class LlamaModel(LlamaPreTrainedModel):
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
......
......@@ -1772,6 +1772,7 @@ class LongformerModel(LongformerPreTrainedModel):
class LongformerForMaskedLM(LongformerPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1770,6 +1770,7 @@ class LongT5Model(LongT5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: LongT5Config):
super().__init__(config)
......@@ -1924,6 +1925,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: LongT5Config):
super().__init__(config)
......@@ -2159,6 +2161,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
)
class LongT5EncoderModel(LongT5PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight"]
def __init__(self, config: LongT5Config):
super().__init__(config)
......
......@@ -1289,6 +1289,7 @@ class LukeForMaskedLM(LukePreTrainedModel):
r"lm_head.decoder.bias",
r"entity_predictions.decoder.weight",
]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1019,6 +1019,7 @@ class LxmertModel(LxmertPreTrainedModel):
)
class LxmertForPreTraining(LxmertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1146,6 +1146,7 @@ class M2M100Model(M2M100PreTrainedModel):
"decoder.embed_positions.weights",
"decoder.embed_positions.bias",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: M2M100Config):
super().__init__(config)
......@@ -1269,6 +1270,7 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
r"decoder.embed_positions.weights",
r"decoder.embed_positions.bias",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: M2M100Config):
super().__init__(config)
......
......@@ -1099,6 +1099,7 @@ class MarianDecoder(MarianPreTrainedModel):
)
class MarianModel(MarianPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: MarianConfig):
super().__init__(config)
......@@ -1294,8 +1295,8 @@ class MarianMTModel(MarianPreTrainedModel):
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: MarianConfig):
super().__init__(config)
......@@ -1556,6 +1557,7 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
class MarianForCausalLM(MarianPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
......
......@@ -1152,6 +1152,7 @@ class MBartDecoder(MBartPreTrainedModel):
)
class MBartModel(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: MBartConfig):
super().__init__(config)
......@@ -1279,6 +1280,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: MBartConfig):
super().__init__(config)
......@@ -1446,6 +1448,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
)
class MBartForSequenceClassification(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
def __init__(self, config: MBartConfig, **kwargs):
super().__init__(config, **kwargs)
......@@ -1575,6 +1578,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
)
class MBartForQuestionAnswering(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1708,6 +1712,7 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
class MBartForCausalLM(MBartPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
......
......@@ -1659,6 +1659,7 @@ class MegaForCausalLM(MegaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.weight", r"lm_head.bias"]
_keys_to_ignore_on_load_missing = [r"lm_head.weight", r"lm_head.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: MegaConfig):
super().__init__(config)
......@@ -1823,6 +1824,7 @@ class MegaForMaskedLM(MegaPreTrainedModel):
_keys_to_ignore_on_save = [r"mlm_head.weight", r"mlm_head.bias"]
_keys_to_ignore_on_load_missing = [r"mlm_head.weight", r"mlm_head.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["mlm_head.weight"]
def __init__(self, config: MegaConfig):
super().__init__(config)
......
......@@ -1015,6 +1015,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
)
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config, add_binary_head=True):
super().__init__(config)
......@@ -1122,6 +1123,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
......@@ -1267,6 +1269,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler", r"seq_relationship"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
......
......@@ -928,6 +928,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
"cls.predictions.decoder.bias",
"embeddings.position_ids",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -1041,6 +1042,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
"cls.predictions.decoder.bias",
"embeddings.position_ids",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -572,6 +572,7 @@ class MPNetModel(MPNetPreTrainedModel):
class MPNetForMaskedLM(MPNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1324,6 +1324,7 @@ class MT5Model(MT5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
# Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
def __init__(self, config: MT5Config):
......@@ -1556,6 +1557,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
# Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
def __init__(self, config: MT5Config):
......@@ -1898,6 +1900,7 @@ class MT5EncoderModel(MT5PreTrainedModel):
r"encoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight"]
# Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5
def __init__(self, config: MT5Config):
......
......@@ -1297,6 +1297,7 @@ class MvpDecoder(MvpPreTrainedModel):
class MvpModel(MvpPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: MvpConfig):
super().__init__(config)
......@@ -1433,6 +1434,7 @@ class MvpModel(MvpPreTrainedModel):
)
class MvpForConditionalGeneration(MvpPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: MvpConfig):
super().__init__(config)
......@@ -1606,6 +1608,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
class MvpForSequenceClassification(MvpPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: MvpConfig, **kwargs):
super().__init__(config, **kwargs)
......@@ -1734,6 +1737,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
class MvpForQuestionAnswering(MvpPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
......@@ -1865,6 +1869,7 @@ class MvpDecoderWrapper(MvpPreTrainedModel):
class MvpForCausalLM(MvpPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
......
......@@ -1038,6 +1038,7 @@ class NezhaModel(NezhaPreTrainedModel):
)
class NezhaForPreTraining(NezhaPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
......@@ -1141,6 +1142,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel):
class NezhaForMaskedLM(NezhaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"cls.predictions.decoder", r"positions_encoding"]
_tied_weights_keys = ["cls.predictions.decoder"]
def __init__(self, config):
super().__init__(config)
......
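For the "consistent function for detecting tied weights" mentioned in the commit message, one common approach is to group parameters by their underlying storage, since tied tensors share memory and therefore the same data pointer. The helper below is an illustrative sketch of that idea only, not the function added by this commit:

# Illustrative sketch: group parameter names by storage pointer; any group with
# more than one name corresponds to a set of tied parameters.
from collections import defaultdict

import torch


def find_tied_parameters(model: torch.nn.Module) -> list[list[str]]:
    """Return groups of parameter names that share the same storage."""
    by_storage = defaultdict(list)
    # remove_duplicate=False keeps tied parameters under each of their names.
    for name, param in model.named_parameters(remove_duplicate=False):
        by_storage[param.data_ptr()].append(name)
    return [names for names in by_storage.values() if len(names) > 1]


# Example with the ToyForCausalLM sketch above:
#   find_tied_parameters(ToyForCausalLM())
#   -> [["embed_tokens.weight", "lm_head.weight"]]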