Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
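The hunks below all add a `_tied_weights_keys` class attribute next to the existing `_keys_to_ignore_on_load_*` attributes, so that each model declares its tied parameters in one consistent place. As a rough sketch of the pattern (a hypothetical `TinyMaskedLM`, not the actual `PreTrainedModel.tie_weights` implementation), the snippet below ties an output projection to the input embedding and then detects the tie by grouping state-dict entries that share the same underlying storage, which is one common way such a check is written:

import torch
import torch.nn as nn


class TinyMaskedLM(nn.Module):
    # Hypothetical example: names listed here point at parameters shared with
    # the input embedding, so loading code should not treat them as missing.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size: int = 1000, hidden_size: int = 64):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Share the output projection with the embedding matrix; both names
        # now refer to the same Parameter object.
        self.lm_head.weight = self.embeddings.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embeddings(input_ids))


model = TinyMaskedLM()
# Detect tied weights by grouping state-dict keys that share storage.
groups = {}
for name, tensor in model.state_dict().items():
    groups.setdefault(tensor.data_ptr(), []).append(name)
print([names for names in groups.values() if len(names) > 1])
# -> [['embeddings.weight', 'lm_head.weight']]

Declaring those names up front in `_tied_weights_keys` gives the loading code a single place to look up which parameters are expected to be tied rather than loaded independently.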
@@ -1022,6 +1022,7 @@ class DebertaModel(DebertaPreTrainedModel):
class DebertaForMaskedLM(DebertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -1122,6 +1122,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -1824,6 +1824,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
_tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
def __init__(self, config: DeformableDetrConfig):
super().__init__(config)
@@ -1776,6 +1776,7 @@ class DetaModel(DetaPreTrainedModel):
class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"]
_tied_weights_keys = [r"bbox_embed\.\d+"]
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
def __init__(self, config: DetaConfig):
@@ -596,6 +596,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
_tied_weights_keys = ["vocab_projector.weight"]
def __init__(self, config: PretrainedConfig):
super().__init__(config)
@@ -1167,6 +1167,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
_tied_weights_keys = ["generator_lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1534,6 +1535,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
)
class ElectraForCausalLM(ElectraPreTrainedModel):
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
_tied_weights_keys = ["generator_lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -984,6 +984,7 @@ class ErnieModel(ErniePreTrainedModel):
)
class ErnieForPreTraining(ErniePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
def __init__(self, config):
@@ -1096,6 +1097,7 @@ class ErnieForPreTraining(ErniePreTrainedModel):
class ErnieForCausalLM(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
def __init__(self, config):
@@ -1243,6 +1245,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
class ErnieForMaskedLM(ErniePreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
def __init__(self, config):
@@ -962,6 +962,7 @@ class EsmModel(EsmPreTrainedModel):
class EsmForMaskedLM(EsmPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", "lm_head.decoder.weight"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -655,6 +655,7 @@ class FlaubertModel(FlaubertPreTrainedModel):
# Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
class FlaubertWithLMHeadModel(FlaubertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
_tied_weights_keys = ["pred_layer.proj.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1730,6 +1730,12 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
"mlm_head.decoder.bias",
"mim_head.decoder.bias",
]
_tied_weights_keys = [
"mmm_text_head.decoder.bias",
"mmm_image_head.decoder.bias",
"mlm_head.decoder.bias",
"mim_head.decoder.bias",
]
def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
super().__init__(config)
@@ -622,6 +622,7 @@ class FNetModel(FNetPreTrainedModel):
)
class FNetForPreTraining(FNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -716,6 +717,7 @@ class FNetForPreTraining(FNetPreTrainedModel):
@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
class FNetForMaskedLM(FNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1035,6 +1035,7 @@ def _get_shape(t):
)
class FSMTModel(PretrainedFSMTModel):
_keys_to_ignore_on_load_missing = ["decoder.output_projection.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight"]
def __init__(self, config: FSMTConfig):
super().__init__(config)
@@ -1180,6 +1181,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
"model.encoder.embed_positions.weight",
"model.decoder.embed_positions.weight",
]
_tied_weights_keys = ["model.decoder.embed_tokens.weight"]
def __init__(self, config: FSMTConfig):
super().__init__(config)
@@ -1191,6 +1191,7 @@ class FunnelForPreTraining(FunnelPreTrainedModel):
@add_start_docstrings("""Funnel Transformer Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
class FunnelForMaskedLM(FunnelPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: FunnelConfig) -> None:
super().__init__(config)
@@ -1324,6 +1324,8 @@ class GitModel(GitPreTrainedModel):
"""GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
)
class GitForCausalLM(GitPreTrainedModel):
_tied_weights_keys = ["output.weight"]
def __init__(self, config):
super().__init__(config)
@@ -959,6 +959,7 @@ class GPT2Model(GPT2PreTrainedModel):
class GPT2LMHeadModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1152,6 +1153,7 @@ input sequence).
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -723,6 +723,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
)
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -669,6 +669,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
r"h\.\d+\.attn\.attention\.bias",
]
_keys_to_ignore_on_save = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -597,6 +597,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
)
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
_tied_weights_keys = ["embed_out.weight"]
def __init__(self, config):
super().__init__(config)
@@ -592,6 +592,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
)
class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "embed_out.weight"]
_tied_weights_keys = ["embed_out.weight"]
def __init__(self, config):
super().__init__(config)
@@ -735,6 +735,7 @@ class GPTJModel(GPTJPreTrainedModel):
)
class GPTJForCausalLM(GPTJPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)