"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "62147ec9c661bf1dc72534e116b4024433ad04c2"
Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
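The diff below applies one pattern across the listed models: each class gains a `_tied_weights_keys` attribute that declares which parameter keys are tied to (share storage with) another parameter, typically the input embeddings, instead of hiding that fact inside `_keys_to_ignore_on_load_missing`. As a rough illustration only (a toy `nn.Module`, not the Transformers internals; the helper `tied_parameter_names` is hypothetical), declaring and detecting a tied weight looks roughly like this:

```python
from torch import nn


class ToyLMHead(nn.Module):
    """Toy LM whose output projection shares its weight with the input embeddings."""

    # Parameter keys whose tensors are tied to (share storage with) another parameter.
    # The idea behind the new attribute: save/load code can consult this list rather than
    # special-casing entries in _keys_to_ignore_on_load_missing.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie: both modules now point at the exact same Parameter object.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids):
        return self.lm_head(self.embed_tokens(input_ids))


def tied_parameter_names(model):
    """One consistent way to detect tied weights: report parameters sharing storage with an earlier one."""
    seen = {}
    tied = []
    for name, param in model.named_parameters(remove_duplicate=False):
        ptr = param.data_ptr()
        if ptr in seen:
            tied.append(name)
        else:
            seen[ptr] = name
    return tied


if __name__ == "__main__":
    model = ToyLMHead()
    print(tied_parameter_names(model))  # ['lm_head.weight'] -- matches the declared _tied_weights_keys
```

The last bullet ("Use consistent function for detecting tied weights") refers to centralizing this kind of check; the sketch above only shows the general idea, not the actual helper used in the library.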
@@ -1509,6 +1509,7 @@ class NllbMoeModel(NllbMoePreTrainedModel):
         "decoder.embed_positions.weights",
         "decoder.embed_positions.bias",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: NllbMoeConfig):
         super().__init__(config)
@@ -1652,6 +1653,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
         r"decoder.embed_positions.weights",
         r"decoder.embed_positions.bias",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: NllbMoeConfig):
         super().__init__(config)
......
@@ -659,6 +659,7 @@ class NystromformerModel(NystromformerPreTrainedModel):
 @add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
 class NystromformerForMaskedLM(NystromformerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -530,6 +530,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
 )
 class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -621,6 +622,7 @@ input sequence).
 )
 class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -818,6 +818,7 @@ class OPTModel(OPTPreTrainedModel):
 class OPTForCausalLM(OPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -1152,6 +1152,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
 )
 class PegasusModel(PegasusPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PegasusConfig):
         super().__init__(config)
@@ -1312,6 +1313,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         "encoder.embed_tokens.weight",
         "decoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PegasusConfig):
         super().__init__(config)
@@ -1512,6 +1514,7 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
 class PegasusForCausalLM(PegasusPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         config = copy.deepcopy(config)
......
@@ -1387,6 +1387,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
 )
 class PegasusXModel(PegasusXPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PegasusXConfig):
         super().__init__(config)
@@ -1538,6 +1539,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
         "decoder.embed_tokens.weight",
         "encoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PegasusXConfig):
         super().__init__(config)
......
@@ -1317,6 +1317,7 @@ PIX2STRUCT_INPUTS_DOCSTRING = r"""
 class Pix2StructTextModel(Pix2StructPreTrainedModel):
     config_class = Pix2StructTextConfig
     _no_split_modules = ["Pix2StructTextBlock"]
+    _tied_weights_keys = ["lm_head.weight"]
     supports_gradient_checkpointing = True
 
     def _set_gradient_checkpointing(self, module, value=False):
@@ -1604,6 +1605,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [
         r"decoder.layer.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
+    _tied_weights_keys = ["decoder.lm_head.weight"]
 
     def __init__(self, config: Pix2StructConfig):
         super().__init__(config)
......
@@ -1128,6 +1128,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
 )
 class PLBartModel(PLBartPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PLBartConfig):
         super().__init__(config)
@@ -1253,6 +1254,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
         "decoder.embed_tokens.weight",
         "encoder.embed_tokens.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
     def __init__(self, config: PLBartConfig):
         super().__init__(config)
@@ -1417,6 +1419,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
 )
 class PLBartForSequenceClassification(PLBartPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
     def __init__(self, config: PLBartConfig, **kwargs):
         super().__init__(config, **kwargs)
@@ -1555,6 +1558,7 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
 class PLBartForCausalLM(PLBartPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
         config = copy.deepcopy(config)
......
@@ -1745,6 +1745,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
 )
 class ProphetNetModel(ProphetNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
 
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
@@ -1878,6 +1879,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         "encoder.word_embeddings.weight",
         "lm_head.weight",
     ]
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
 
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
@@ -2090,6 +2092,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
 )
 class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: ProphetNetConfig):
         # set config for CLM
......
@@ -1014,6 +1014,7 @@ class QDQBertModel(QDQBertPreTrainedModel):
 class QDQBertLMHeadModel(QDQBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1167,6 +1168,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
 class QDQBertForMaskedLM(QDQBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -1148,6 +1148,7 @@ class RealmBertModel(RealmPreTrainedModel):
 )
 class RealmEmbedder(RealmPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1378,6 +1379,7 @@ class RealmScorer(RealmPreTrainedModel):
 )
 class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
+    _tied_weights_keys = ["cls.predictions.decoder"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -2186,6 +2186,7 @@ class ReformerModel(ReformerPreTrainedModel):
 @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
 class ReformerModelWithLMHead(ReformerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.decoder.bias"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -2311,6 +2312,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
 @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
 class ReformerForMaskedLM(ReformerPreTrainedModel):
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
         assert not config.is_decoder, (
......
@@ -912,6 +912,8 @@ class RemBertModel(RemBertPreTrainedModel):
 @add_start_docstrings("""RemBERT Model with a `language modeling` head on top.""", REMBERT_START_DOCSTRING)
 class RemBertForMaskedLM(RemBertPreTrainedModel):
+    _tied_weights_keys = ["cls.predictions.decoder.weight"]
+
     def __init__(self, config):
         super().__init__(config)
@@ -1015,6 +1017,7 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
 )
 class RemBertForCausalLM(RemBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -884,6 +884,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1038,6 +1039,7 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -889,6 +889,7 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1047,6 +1048,7 @@ class RobertaPreLayerNormForMaskedLM(RobertaPreLayerNormPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
     def __init__(self, config):
......
@@ -1082,6 +1082,7 @@ class RoCBertModel(RoCBertPreTrainedModel):
 )
 class RoCBertForPreTraining(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1268,6 +1269,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
 class RoCBertForMaskedLM(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
 
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
     def __init__(self, config):
@@ -1409,6 +1411,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
 class RoCBertForCausalLM(RoCBertPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
 
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
     def __init__(self, config):
......
@@ -953,6 +953,7 @@ class RoFormerModel(RoFormerPreTrainedModel):
 @add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
 class RoFormerForMaskedLM(RoFormerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1055,6 +1056,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
 )
 class RoFormerForCausalLM(RoFormerPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -732,6 +732,8 @@ class RwkvModel(RwkvPreTrainedModel):
     RWKV_START_DOCSTRING,
 )
 class RwkvForCausalLM(RwkvPreTrainedModel):
+    _tied_weights_keys = ["head.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.rwkv = RwkvModel(config)
......
@@ -1191,6 +1191,7 @@ SAM_INPUTS_DOCSTRING = r"""
 )
 class SamModel(SamPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
+    _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
 
     def __init__(self, config):
         super().__init__(config)
......
@@ -1272,6 +1272,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         r"model.encoder.embed_positions.weights",
         r"model.decoder.embed_positions.weights",
     ]
+    _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: Speech2TextConfig):
         super().__init__(config)
......