Unverified commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
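For context, "tied" parameters are distinct `state_dict` keys that point at the same underlying tensor, most commonly a language-modeling head sharing its matrix with the input embeddings. A minimal plain-PyTorch sketch of the situation this PR catalogues (illustrative names only):

```python
import torch.nn as nn

embeddings = nn.Embedding(100, 16)        # input embeddings
lm_head = nn.Linear(16, 100, bias=False)  # output projection
lm_head.weight = embeddings.weight        # tie: one Parameter, two modules

# Two state_dict keys now alias the same storage, so a checkpoint only needs
# one copy -- which is what each model's `_tied_weights_keys` declares below.
assert lm_head.weight is embeddings.weight
```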
@@ -1069,6 +1069,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
# trained, but which are either deterministic or tied variables)
_keys_to_ignore_on_save = None
+ # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
+ _tied_weights_keys = None
is_parallelizable = False
supports_gradient_checkpointing = False
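The new class attribute complements `_keys_to_ignore_on_load_missing`: keys listed here may legitimately be absent from a checkpoint because `tie_weights()` recreates the alias after loading. A hypothetical sketch of how a loader could consume the list (`filter_tied` and the key names are illustrative, not the actual `from_pretrained` code):

```python
import re

def filter_tied(missing_keys, tied_weights_keys):
    # Keys tied to another parameter are rebuilt by tie_weights(), so they
    # should not be reported as missing from the checkpoint.
    if not tied_weights_keys:
        return missing_keys
    pattern = "|".join(re.escape(key) for key in tied_weights_keys)
    return [key for key in missing_keys if re.search(pattern, key) is None]

print(filter_tied(["lm_head.weight", "encoder.layer.0.bias"], ["lm_head.weight"]))
# -> ['encoder.layer.0.bias']
```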
@@ -1778,8 +1780,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
- ident = (tensor.data_ptr(), tensor.device, tensor.shape, tensor.stride())
- ptrs[ident].append(name)
+ ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
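`id_tensor_storage` (imported from `transformers.pytorch_utils` in this patch) identifies a tensor by its underlying storage rather than by `(data_ptr, device, shape, stride)`, so two views of the same memory are grouped as one alias even when their shapes or strides differ. Roughly (a simplified sketch, not the exact helper; assumes PyTorch 2.x for `untyped_storage`):

```python
import collections
import torch

def id_tensor_storage_sketch(tensor: torch.Tensor):
    # Two tensors alias iff they sit on the same device and share storage;
    # shape and stride are deliberately ignored.
    storage = tensor.untyped_storage()
    return tensor.device, storage.data_ptr(), storage.nbytes()

weight = torch.randn(4, 8)
view = weight.t()  # different shape and stride, same storage

ptrs = collections.defaultdict(list)
for name, tensor in {"a.weight": weight, "b.weight": view}.items():
    ptrs[id_tensor_storage_sketch(tensor)].append(name)
print({k: v for k, v in ptrs.items() if len(v) > 1})
# Both names land in one group; the old tuple would have kept them apart.
```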
@@ -759,6 +759,7 @@ class AlbertModel(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
@@ -912,6 +913,7 @@ class AlbertSOPHead(nn.Module):
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
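When `config.tie_word_embeddings` is left at its default of `True`, these keys are genuine aliases of the input embeddings, which a quick check confirms (a sketch using the standard `albert-base-v2` checkpoint):

```python
from transformers import AlbertForMaskedLM

model = AlbertForMaskedLM.from_pretrained("albert-base-v2")
# The MLM decoder listed in _tied_weights_keys shares its matrix with the
# input embeddings after tie_weights() has run inside from_pretrained:
assert model.get_output_embeddings().weight is model.get_input_embeddings().weight
```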
@@ -1166,6 +1166,7 @@ class BartDecoder(BartPretrainedModel):
)
class BartModel(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig):
super().__init__(config)
@@ -1293,9 +1294,10 @@ class BartModel(BartPretrainedModel):
)
class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"lm_head.weight",
"final_logits_bias",
"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
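BART keeps a single shared embedding matrix, and the three names above are all views of it; a sketch against `facebook/bart-base` (attribute paths as in `modeling_bart.py`):

```python
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
shared = model.model.shared.weight
# All three _tied_weights_keys entries alias the shared embedding matrix:
assert model.model.encoder.embed_tokens.weight is shared
assert model.model.decoder.embed_tokens.weight is shared
assert model.lm_head.weight is shared
```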
@@ -1472,6 +1474,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
)
class BartForSequenceClassification(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -1602,6 +1605,7 @@ class BartForSequenceClassification(BartPretrainedModel):
)
class BartForQuestionAnswering(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1740,6 +1744,7 @@ class BartDecoderWrapper(BartPretrainedModel):
)
class BartForCausalLM(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
@@ -1054,6 +1054,7 @@ class BertModel(BertPreTrainedModel):
)
class BertForPreTraining(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1161,6 +1162,7 @@ class BertForPreTraining(BertPreTrainedModel):
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1301,6 +1303,7 @@ class BertLMHeadModel(BertPreTrainedModel):
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -861,6 +861,7 @@ class BertGenerationOnlyLMHead(nn.Module):
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -2262,6 +2262,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
class BigBirdForPreTraining(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -2368,6 +2369,7 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -2517,6 +2519,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
"cls.predictions.decoder.weight",
"cls.predictions.decoder.bias",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -2354,6 +2354,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
)
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
@@ -2484,9 +2485,10 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"lm_head.weight",
"final_logits_bias",
"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
@@ -2663,6 +2665,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
)
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -2792,6 +2795,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
)
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
@@ -2924,6 +2928,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
@@ -647,6 +647,7 @@ class BioGptModel(BioGptPreTrainedModel):
)
class BioGptForCausalLM(BioGptPreTrainedModel):
_keys_to_ignore_on_load_missing = ["output_projection.weight"]
_tied_weights_keys = ["output_projection.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1098,6 +1098,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
)
class BlenderbotModel(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
@@ -1246,6 +1247,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
@@ -1435,6 +1437,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
@@ -1092,6 +1092,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
)
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotSmallConfig):
super().__init__(config)
@@ -1228,6 +1229,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: BlenderbotSmallConfig):
super().__init__(config)
@@ -1402,6 +1404,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
@@ -930,6 +930,7 @@ class BlipModel(BlipPreTrainedModel):
class BlipForConditionalGeneration(BlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
main_input_name = "pixel_values"
def __init__(self, config: BlipConfig):
@@ -1102,6 +1103,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
class BlipForQuestionAnswering(BlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
def __init__(self, config: BlipConfig):
super().__init__(config)
@@ -1232,6 +1232,11 @@ class Blip2Model(Blip2PreTrainedModel):
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+ # Update _tied_weights_keys using the base model used.
+ if language_model._tied_weights_keys is not None:
+     self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
self.language_model = language_model
# Initialize weights and apply final processing
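Because BLIP-2 wraps an arbitrary language model, the tied keys cannot be hard-coded on the class: they are copied from the chosen submodel and re-prefixed with the attribute name so they resolve in the composite `state_dict`. For example (a sketch; the exact keys depend on the configured language model):

```python
from transformers import Blip2Config, Blip2ForConditionalGeneration

config = Blip2Config.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration(config)  # random init is enough here
# With an OPT decoder, whose own _tied_weights_keys is ["lm_head.weight"],
# this should print ["language_model.lm_head.weight"]:
print(model._tied_weights_keys)
```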
@@ -1587,6 +1592,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
+ # Update _tied_weights_keys using the base model used.
+ if language_model._tied_weights_keys is not None:
+     self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
self.language_model = language_model
# Initialize weights and apply final processing
@@ -827,6 +827,7 @@ class BloomModel(BloomPreTrainedModel):
)
class BloomForCausalLM(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: BloomConfig):
super().__init__(config)
@@ -1265,6 +1265,12 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
self.post_init()
def get_input_embeddings(self):
return self.text_model.get_input_embeddings()
+ def set_input_embeddings(self, value):
+     self.text_model.set_input_embeddings(value)
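The new setter mirrors the existing getter so that generic `PreTrainedModel` machinery (such as `resize_token_embeddings` and `tie_weights`) can reach the text model's embeddings through the composite model. A self-contained sketch of the delegation pattern (toy classes, not the real BridgeTower modules):

```python
import torch.nn as nn

class ToyTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(100, 16)

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

class ToyComposite(nn.Module):
    # Mirrors the BridgeTower change: both accessors delegate to the submodel.
    def __init__(self):
        super().__init__()
        self.text_model = ToyTextModel()

    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

model = ToyComposite()
new_embeddings = nn.Embedding(120, 16)  # e.g. after growing the vocabulary
model.set_input_embeddings(new_embeddings)
assert model.get_input_embeddings() is new_embeddings
```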
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BridgeTowerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1548,6 +1554,8 @@ class BridgeTowerITMHead(nn.Module):
BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
_tied_weights_keys = ["mlm_score.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -938,6 +938,7 @@ class CamembertForMaskedLM(CamembertPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -1433,6 +1434,7 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -601,6 +601,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
)
class CodeGenForCausalLM(CodeGenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -881,6 +881,7 @@ class ConvBertGeneratorPredictions(nn.Module):
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]
_tied_weights_keys = ["generator.lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -750,6 +750,7 @@ class CpmAntModel(CpmAntPreTrainedModel):
)
class CpmAntForCausalLM(CpmAntPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: CpmAntConfig):
super().__init__(config)
@@ -510,6 +510,7 @@ class CTRLModel(CTRLPreTrainedModel):
)
class CTRLLMHeadModel(CTRLPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
@@ -886,6 +886,7 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -1040,6 +1041,7 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
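In the spirit of the "First test" and "Use consistent function for detecting tied weights" commits above, the declared tied keys can be cross-checked against actual storage sharing with the same helper this patch uses in `save_pretrained` (a sketch, using `facebook/bart-base` as an example checkpoint):

```python
import collections

from transformers import BartForConditionalGeneration
from transformers.pytorch_utils import id_tensor_storage

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Group every state_dict entry by storage identity, as save_pretrained now does:
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[id_tensor_storage(tensor)].append(name)
shared_groups = [names for names in ptrs.values() if len(names) > 1]

# Every declared tied key should land in one of the shared groups (the keys
# are relative to the class, so compare by suffix):
for key in model._tied_weights_keys:
    assert any(name.endswith(key) for group in shared_groups for name in group)
```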