Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
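
Note on the attribute this commit introduces: `_tied_weights_keys` lists the state-dict keys (relative to the model root) whose tensors share storage with another parameter, so they can be dropped when serializing and are expected to be absent when loading. A minimal sketch of how such a list could be consumed, assuming a simplified save path (the helper name and flow here are illustrative, not the library's exact internals):

import re

from torch import nn


def strip_tied_weights(model: nn.Module, tied_keys) -> dict:
    """Drop tied duplicates from a state dict before saving (sketch only)."""
    state_dict = dict(model.state_dict())
    for pattern in tied_keys or []:
        for key in list(state_dict):
            # Entries may be plain keys or regex-like patterns; match the whole key.
            if re.fullmatch(pattern, key):
                del state_dict[key]
    return state_dict

The tied tensors are re-created by re-tying at load time, which is why the same keys previously had to be repeated in lists like `_keys_to_ignore_on_load_missing` to silence spurious warnings.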
@@ -1509,6 +1509,7 @@ class NllbMoeModel(NllbMoePreTrainedModel):
         "decoder.embed_positions.weights",
         "decoder.embed_positions.bias",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: NllbMoeConfig):
         super().__init__(config)
@@ -1652,6 +1653,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
         r"decoder.embed_positions.weights",
         r"decoder.embed_positions.bias",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: NllbMoeConfig):
         super().__init__(config)
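The encoder/decoder pairs above are tied because seq2seq models in the library typically build both token embeddings from one shared `nn.Embedding`. A toy illustration of why two state-dict keys then point at a single tensor (the class and sizes are made up for the example):

from torch import nn


class ToySeq2Seq(nn.Module):
    # Stand-in for an NllbMoe/Pegasus-style model, not the real implementation.
    def __init__(self, vocab_size=32, d_model=8):
        super().__init__()
        shared = nn.Embedding(vocab_size, d_model)
        self.shared = shared
        # The same Embedding module is registered under both branches, so
        # "encoder.embed_tokens.weight" and "decoder.embed_tokens.weight"
        # are two keys for one underlying tensor.
        self.encoder = nn.ModuleDict({"embed_tokens": shared})
        self.decoder = nn.ModuleDict({"embed_tokens": shared})


sd = ToySeq2Seq().state_dict()
assert sd["encoder.embed_tokens.weight"].data_ptr() == sd["decoder.embed_tokens.weight"].data_ptr()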
@@ -659,6 +659,7 @@ class NystromformerModel(NystromformerPreTrainedModel):
 @add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
 class NystromformerForMaskedLM(NystromformerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -530,6 +530,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
 )
 class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -621,6 +622,7 @@ input sequence).
 )
 class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -818,6 +818,7 @@ class OPTModel(OPTPreTrainedModel):
 class OPTForCausalLM(OPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1152,6 +1152,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
 )
 class PegasusModel(PegasusPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PegasusConfig):
         super().__init__(config)
@@ -1312,6 +1313,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         "encoder.embed_tokens.weight",
         "decoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PegasusConfig):
         super().__init__(config)
@@ -1512,6 +1514,7 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
 class PegasusForCausalLM(PegasusPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         config = copy.deepcopy(config)
@@ -1387,6 +1387,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
 )
 class PegasusXModel(PegasusXPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PegasusXConfig):
         super().__init__(config)
@@ -1538,6 +1539,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         "decoder.embed_tokens.weight",
         "encoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PegasusXConfig):
         super().__init__(config)
@@ -1317,6 +1317,7 @@ PIX2STRUCT_INPUTS_DOCSTRING = r"""
 class Pix2StructTextModel(Pix2StructPreTrainedModel):
     config_class = Pix2StructTextConfig
     _no_split_modules = ["Pix2StructTextBlock"]
+    _tied_weights_keys = ["lm_head.weight"]
     supports_gradient_checkpointing = True
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -1604,6 +1605,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [
         r"decoder.layer.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
+    _tied_weights_keys = ["decoder.lm_head.weight"]
 
     def __init__(self, config: Pix2StructConfig):
         super().__init__(config)
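Worth noting in the Pix2Struct hunks: the keys are relative to the top-level model, so the same tied head appears as `lm_head.weight` on the standalone text model but as `decoder.lm_head.weight` once that model is nested inside `Pix2StructForConditionalGeneration`. A hypothetical helper (not part of the library) makes the re-rooting explicit:

def prefix_tied_keys(tied_keys, prefix):
    # Illustrative only: re-root a submodel's tied keys under its attribute name.
    return [f"{prefix}.{key}" for key in tied_keys]


assert prefix_tied_keys(["lm_head.weight"], "decoder") == ["decoder.lm_head.weight"]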
@@ -1128,6 +1128,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
 )
 class PLBartModel(PLBartPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PLBartConfig):
         super().__init__(config)
@@ -1253,6 +1254,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         "decoder.embed_tokens.weight",
         "encoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PLBartConfig):
         super().__init__(config)
@@ -1417,6 +1419,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
 )
 class PLBartForSequenceClassification(PLBartPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PLBartConfig, **kwargs):
         super().__init__(config, **kwargs)
@@ -1555,6 +1558,7 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
 class PLBartForCausalLM(PLBartPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         config = copy.deepcopy(config)
@@ -1745,6 +1745,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
 )
 class ProphetNetModel(ProphetNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
 
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
@@ -1878,6 +1879,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         "encoder.word_embeddings.weight",
         "lm_head.weight",
     ]
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
 
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
@@ -2090,6 +2092,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
 )
 class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: ProphetNetConfig):
         # set config for CLM
@@ -1014,6 +1014,7 @@ class QDQBertModel(QDQBertPreTrainedModel):
 class QDQBertLMHeadModel(QDQBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1167,6 +1168,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
 class QDQBertForMaskedLM(QDQBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1148,6 +1148,7 @@ class RealmBertModel(RealmPreTrainedModel):
 )
 class RealmEmbedder(RealmPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1378,6 +1379,7 @@ class RealmScorer(RealmPreTrainedModel):
 )
 class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -2186,6 +2186,7 @@ class ReformerModel(ReformerPreTrainedModel):
 @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
 class ReformerModelWithLMHead(ReformerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.decoder.bias"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -2311,6 +2312,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
 @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
 class ReformerForMaskedLM(ReformerPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
         assert not config.is_decoder, (
@@ -912,6 +912,8 @@ class RemBertModel(RemBertPreTrainedModel):
 @add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING)
 class RemBertForMaskedLM(RemBertPreTrainedModel):
+    _tied_weights_keys = ["cls.predictions.decoder.weight"]
+
     def __init__(self, config):
         super().__init__(config)
@@ -1015,6 +1017,7 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
 )
 class RemBertForCausalLM(RemBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -884,6 +884,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1038,6 +1039,7 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
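The commit message's "consistent function for detecting tied weights" presumably means grouping tensors by storage identity rather than trusting hand-maintained lists like the `_keys_to_ignore_on_save` entries above. A rough sketch of that kind of detection (the library's actual helper may differ):

from collections import defaultdict

from torch import nn


def find_tied_keys(model: nn.Module):
    """Group state-dict keys that share one underlying storage (sketch only)."""
    groups = defaultdict(list)
    for key, tensor in model.state_dict().items():
        groups[tensor.data_ptr()].append(key)
    # Any group with more than one key is a set of tied weights.
    return [keys for keys in groups.values() if len(keys) > 1]


linear = nn.Linear(4, 4, bias=False)
tied = nn.ModuleDict({"a": linear, "b": linear})
assert find_tied_keys(tied) == [["a.weight", "b.weight"]]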
@@ -889,6 +889,7 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1047,6 +1048,7 @@ class RobertaPreLayerNormForMaskedLM(RobertaPreLayerNormPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
     def __init__(self, config):
@@ -1082,6 +1082,7 @@ class RoCBertModel(RoCBertPreTrainedModel):
 )
 class RoCBertForPreTraining(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1268,6 +1269,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
 class RoCBertForMaskedLM(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
 
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
     def __init__(self, config):
@@ -1409,6 +1411,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
 class RoCBertForCausalLM(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
 
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
     def __init__(self, config):
@@ -953,6 +953,7 @@ class RoFormerModel(RoFormerPreTrainedModel):
 @add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
 class RoFormerForMaskedLM(RoFormerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1055,6 +1056,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
 )
 class RoFormerForCausalLM(RoFormerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -732,6 +732,8 @@ class RwkvModel(RwkvPreTrainedModel):
     RWKV_START_DOCSTRING,
 )
 class RwkvForCausalLM(RwkvPreTrainedModel):
+    _tied_weights_keys = ["head.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.rwkv = RwkvModel(config)
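RwkvForCausalLM lists only `head.weight`: the output projection tied back to the input embeddings (in transformers this kind of tie is generally gated by the config's `tie_word_embeddings` flag). The pattern in toy form, not the RWKV implementation itself:

from torch import nn


class ToyCausalLM(nn.Module):
    def __init__(self, vocab_size=32, hidden=8, tie_word_embeddings=True):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size, bias=False)
        if tie_word_embeddings:
            # Reuse the embedding matrix as the output projection weights.
            self.head.weight = self.embeddings.weight


model = ToyCausalLM()
assert model.head.weight.data_ptr() == model.embeddings.weight.data_ptr()

Because the tie can be config-dependent, listing the key tells the save/load machinery that its absence from a checkpoint is expected rather than an error.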
@@ -1191,6 +1191,7 @@ SAM_INPUTS_DOCSTRING = r"""
 )
 class SamModel(SamPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
+    _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1272,6 +1272,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         r"model.encoder.embed_positions.weights",
         r"model.decoder.embed_positions.weights",
     ]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: Speech2TextConfig):
         super().__init__(config)
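A quick way to see the tying end-to-end on a released checkpoint (this example uses `gpt2`, whose LM head is tied to its input embeddings; it is not one of the files touched above and requires downloading the model):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# The tied head and the input embedding matrix are one tensor.
assert model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr()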