Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
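
This commit adds an explicit `_tied_weights_keys` class attribute to each model that ties parameters, so that one shared routine can detect tied weights consistently instead of every model relying on ad-hoc `_keys_to_ignore_on_load_missing` entries. As a rough illustration of the idea (a minimal sketch, not the actual `transformers` implementation; `TinyLMHead` and `load_with_tied_keys` are made up for this example), such an attribute can be consulted when loading a checkpoint so that parameters which are intentionally shared, and therefore not serialized separately, are not reported as missing:

```python
# Minimal, self-contained sketch (NOT the transformers implementation) of how an explicit
# `_tied_weights_keys` attribute can be consumed at load time: tied parameter names are
# filtered out of the "missing keys" report. Regex matching mirrors the fact that some
# models (e.g. DeformableDetrForObjectDetection below) list patterns rather than exact names.
import re
import torch
from torch import nn


class TinyLMHead(nn.Module):
    # Names (or regex patterns) of parameters whose storage is shared with another parameter.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size=10, hidden=4):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.lm_head.weight = self.embeddings.weight  # tie output projection to input embeddings

    def forward(self, input_ids):
        return self.lm_head(self.embeddings(input_ids))


def load_with_tied_keys(model, state_dict):
    """Load `state_dict`, ignoring tied parameters that were legitimately not serialized."""
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    patterns = getattr(model, "_tied_weights_keys", None) or []
    missing = [k for k in missing if not any(re.search(p, k) for p in patterns)]
    return missing, unexpected


model = TinyLMHead()
# Simulate a checkpoint saved without the tied weight (it would only duplicate embeddings.weight).
checkpoint = {k: v for k, v in model.state_dict().items() if k != "lm_head.weight"}
missing, unexpected = load_with_tied_keys(TinyLMHead(), checkpoint)
print(missing, unexpected)  # -> [] []  (the tied key is not reported as missing)
```

Each hunk below simply declares the attribute on the affected model class; the shared detection logic lives in the common `PreTrainedModel` machinery.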
@@ -1022,6 +1022,7 @@ class DebertaModel(DebertaPreTrainedModel):
 class DebertaForMaskedLM(DebertaPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1122,6 +1122,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
 class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1824,6 +1824,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
 class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
     _keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
+    _tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]

     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
...
@@ -1776,6 +1776,7 @@ class DetaModel(DetaPreTrainedModel):
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
     _keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
+    _tied_weights_keys = [r"bbox_embed\.\d+"]

     # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
     def __init__(self, config: DetaConfig):
...
@@ -596,6 +596,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
 )
 class DistilBertForMaskedLM(DistilBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
+    _tied_weights_keys = ["vocab_projector.weight"]

     def __init__(self, config: PretrainedConfig):
         super().__init__(config)
...
@@ -1167,6 +1167,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
 )
 class ElectraForMaskedLM(ElectraPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
+    _tied_weights_keys = ["generator_lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1534,6 +1535,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
 )
 class ElectraForCausalLM(ElectraPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
+    _tied_weights_keys = ["generator_lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -984,6 +984,7 @@ class ErnieModel(ErniePreTrainedModel):
 )
 class ErnieForPreTraining(ErniePreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
     def __init__(self, config):
@@ -1096,6 +1097,7 @@ class ErnieForPreTraining(ErniePreTrainedModel):
 class ErnieForCausalLM(ErniePreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
     def __init__(self, config):
@@ -1243,6 +1245,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
 class ErnieForMaskedLM(ErniePreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
     def __init__(self, config):
...
@@ -962,6 +962,7 @@ class EsmModel(EsmPreTrainedModel):
 class EsmForMaskedLM(EsmPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", "lm_head.decoder.weight"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -655,6 +655,7 @@ class FlaubertModel(FlaubertPreTrainedModel):
 # Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
 class FlaubertWithLMHeadModel(FlaubertPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
+    _tied_weights_keys = ["pred_layer.proj.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1730,6 +1730,12 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
         "mlm_head.decoder.bias",
         "mim_head.decoder.bias",
     ]
+    _tied_weights_keys = [
+        "mmm_text_head.decoder.bias",
+        "mmm_image_head.decoder.bias",
+        "mlm_head.decoder.bias",
+        "mim_head.decoder.bias",
+    ]

     def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
         super().__init__(config)
...
@@ -622,6 +622,7 @@ class FNetModel(FNetPreTrainedModel):
 )
 class FNetForPreTraining(FNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -716,6 +717,7 @@ class FNetForPreTraining(FNetPreTrainedModel):
 @add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
 class FNetForMaskedLM(FNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -1035,6 +1035,7 @@ def _get_shape(t):
 )
 class FSMTModel(PretrainedFSMTModel):
     _keys_to_ignore_on_load_missing = ["decoder.output_projection.weight"]
+    _tied_weights_keys = ["decoder.embed_tokens.weight"]

     def __init__(self, config: FSMTConfig):
         super().__init__(config)
@@ -1180,6 +1181,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
+    _tied_weights_keys = ["model.decoder.embed_tokens.weight"]

     def __init__(self, config: FSMTConfig):
         super().__init__(config)
...
@@ -1191,6 +1191,7 @@ class FunnelForPreTraining(FunnelPreTrainedModel):
 @add_start_docstrings("""Funnel Transformer Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
 class FunnelForMaskedLM(FunnelPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: FunnelConfig) -> None:
         super().__init__(config)
...
@@ -1324,6 +1324,8 @@ class GitModel(GitPreTrainedModel):
     """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
 )
 class GitForCausalLM(GitPreTrainedModel):
+    _tied_weights_keys = ["output.weight"]
+
     def __init__(self, config):
         super().__init__(config)
...
@@ -959,6 +959,7 @@ class GPT2Model(GPT2PreTrainedModel):
 class GPT2LMHeadModel(GPT2PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
     _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1152,6 +1153,7 @@ input sequence).
 class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -723,6 +723,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
 )
 class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -669,6 +669,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
         r"h\.\d+\.attn\.attention\.bias",
     ]
     _keys_to_ignore_on_save = [r"lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -597,6 +597,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
 )
 class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _tied_weights_keys = ["embed_out.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -592,6 +592,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
 )
 class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "embed_out.weight"]
+    _tied_weights_keys = ["embed_out.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -735,6 +735,7 @@ class GPTJModel(GPTJPreTrainedModel):
 )
 class GPTJForCausalLM(GPTJPreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
...