"docs/vscode:/vscode.git/clone" did not exist on "c33f6046c3dab8f41bedf893404e6469dea3bce8"
Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
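The gist of the change: instead of relying only on `_keys_to_ignore_on_load_missing`, each model now declares the `state_dict` keys that merely alias another parameter in a dedicated `_tied_weights_keys` class attribute, and `save_pretrained` detects aliased tensors through a single helper. A toy illustration of what a tied key means, assuming plain PyTorch (this is a sketch, not a transformers class):

import torch
from torch import nn

class ToyLMHead(nn.Module):
    # Hypothetical model: the decoder weight is tied to the input embedding, so the
    # "decoder.weight" entry in the state_dict aliases "embeddings.weight".
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden)
        self.decoder = nn.Linear(hidden, vocab_size, bias=False)
        self.decoder.weight = self.embeddings.weight  # tie the two parameters

model = ToyLMHead()
# Both state_dict entries point at the same storage; only one copy needs saving.
assert model.state_dict()["decoder.weight"].data_ptr() == model.state_dict()["embeddings.weight"].data_ptr()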
@@ -1069,6 +1069,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
# trained, but which are either deterministic or tied variables)
_keys_to_ignore_on_save = None
# a list of `state_dict` keys that are potentially tied to another key in the state_dict.
_tied_weights_keys = None
is_parallelizable = False
supports_gradient_checkpointing = False
@@ -1778,8 +1780,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
-    ident = (tensor.data_ptr(), tensor.device, tensor.shape, tensor.stride())
-    ptrs[ident].append(name)
+    ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
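For reference, the replaced four-field `ident` tuple could only match exact duplicates, while grouping by storage identity also catches views of the same storage; that is what `id_tensor_storage` is meant to provide. A minimal sketch of the idea in plain PyTorch (the `storage_id` helper below is a stand-in, not the actual transformers implementation):

import collections
import torch

def storage_id(tensor: torch.Tensor):
    # Stand-in helper: identify a tensor by the device and address of its underlying
    # storage (PyTorch 2.x API; older versions expose tensor.storage() instead).
    return (tensor.device, tensor.untyped_storage().data_ptr())

weight = torch.randn(10, 4)
state_dict = {
    "decoder.weight": weight,
    "lm_head.weight": weight,      # a tied parameter: same tensor object
    "lm_head.t_view": weight.t(),  # a view: same storage, different stride
}

ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
    ptrs[storage_id(tensor)].append(name)

# All names sharing a storage are aliases of one another; keep a single copy on save.
shared = {key: names for key, names in ptrs.items() if len(names) > 1}
print(shared)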
...
@@ -759,6 +759,7 @@ class AlbertModel(AlbertPreTrainedModel):
ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
@@ -912,6 +913,7 @@ class AlbertSOPHead(nn.Module):
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
...
@@ -1166,6 +1166,7 @@ class BartDecoder(BartPretrainedModel):
)
class BartModel(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig):
super().__init__(config)
@@ -1293,9 +1294,10 @@ class BartModel(BartPretrainedModel):
)
class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
-    r"final_logits_bias",
-    r"lm_head.weight",
+    "final_logits_bias",
+    "lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
@@ -1472,6 +1474,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
)
class BartForSequenceClassification(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -1602,6 +1605,7 @@ class BartForSequenceClassification(BartPretrainedModel):
)
class BartForQuestionAnswering(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1740,6 +1744,7 @@ class BartDecoderWrapper(BartPretrainedModel):
)
class BartForCausalLM(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
...
@@ -1054,6 +1054,7 @@ class BertModel(BertPreTrainedModel):
)
class BertForPreTraining(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1161,6 +1162,7 @@ class BertForPreTraining(BertPreTrainedModel):
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
@@ -1301,6 +1303,7 @@ class BertLMHeadModel(BertPreTrainedModel):
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
def __init__(self, config):
super().__init__(config)
...
@@ -861,6 +861,7 @@ class BertGenerationOnlyLMHead(nn.Module):
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
...
@@ -2262,6 +2262,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
class BigBirdForPreTraining(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -2368,6 +2369,7 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -2517,6 +2519,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
"cls.predictions.decoder.weight",
"cls.predictions.decoder.bias",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
...
@@ -2354,6 +2354,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
)
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
@@ -2484,9 +2485,10 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
-    r"final_logits_bias",
-    r"lm_head.weight",
+    "final_logits_bias",
+    "lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
@@ -2663,6 +2665,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
)
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
super().__init__(config, **kwargs)
@@ -2792,6 +2795,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
)
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
@@ -2924,6 +2928,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
...
@@ -647,6 +647,7 @@ class BioGptModel(BioGptPreTrainedModel):
)
class BioGptForCausalLM(BioGptPreTrainedModel):
_keys_to_ignore_on_load_missing = ["output_projection.weight"]
_tied_weights_keys = ["output_projection.weight"]
def __init__(self, config):
super().__init__(config)
...
@@ -1098,6 +1098,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
)
class BlenderbotModel(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
@@ -1246,6 +1247,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
"decoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
@@ -1435,6 +1437,7 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
...
@@ -1092,6 +1092,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
)
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
def __init__(self, config: BlenderbotSmallConfig):
super().__init__(config)
@@ -1228,6 +1229,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: BlenderbotSmallConfig):
super().__init__(config)
@@ -1402,6 +1404,7 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
...
@@ -930,6 +930,7 @@ class BlipModel(BlipPreTrainedModel):
class BlipForConditionalGeneration(BlipPreTrainedModel):
config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
main_input_name = "pixel_values" main_input_name = "pixel_values"
def __init__(self, config: BlipConfig): def __init__(self, config: BlipConfig):
...@@ -1102,6 +1103,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel): ...@@ -1102,6 +1103,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
class BlipForQuestionAnswering(BlipPreTrainedModel): class BlipForQuestionAnswering(BlipPreTrainedModel):
config_class = BlipConfig config_class = BlipConfig
_keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"]
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
def __init__(self, config: BlipConfig):
super().__init__(config)
...
@@ -1232,6 +1232,11 @@ class Blip2Model(Blip2PreTrainedModel):
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
self.language_model = language_model
# Initialize weights and apply final processing
@@ -1587,6 +1592,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
self.language_model = language_model
# Initialize weights and apply final processing
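The two Blip2 hunks above follow the same pattern: because the wrapped language model can be any `AutoModelForCausalLM`/`AutoModelForSeq2SeqLM`, its tied keys are re-exported under the `language_model.` prefix so they stay valid from the wrapper's point of view. A reduced sketch of that pattern (illustrative only, not the Blip2 implementation):

class ToyWrapper:
    # Illustrative wrapper around an arbitrary sub-model.
    _tied_weights_keys = None

    def __init__(self, language_model):
        # Re-export the sub-model's tied keys, prefixed by the attribute name that
        # holds the sub-model, so they resolve against the wrapper's state_dict.
        if getattr(language_model, "_tied_weights_keys", None) is not None:
            self._tied_weights_keys = [f"language_model.{key}" for key in language_model._tied_weights_keys]
        self.language_model = language_model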
...
@@ -827,6 +827,7 @@ class BloomModel(BloomPreTrainedModel):
)
class BloomForCausalLM(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: BloomConfig):
super().__init__(config)
...
@@ -1265,6 +1265,12 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
self.post_init()
def get_input_embeddings(self):
return self.text_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(BRIDGETOWER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BridgeTowerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
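The accessors added to BridgeTowerModel above matter because the base-class tying and embedding-resizing utilities locate embeddings through `get_input_embeddings()`/`get_output_embeddings()` rather than by attribute name. A simplified sketch of how tying typically resolves the pair (not the actual transformers code):

def tie_output_to_input_embeddings(model):
    # Simplified: the real utility also handles cloning for TorchScript and bias resizing.
    output_embeddings = model.get_output_embeddings()
    if output_embeddings is not None and getattr(model.config, "tie_word_embeddings", True):
        output_embeddings.weight = model.get_input_embeddings().weight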
@@ -1548,6 +1554,8 @@ class BridgeTowerITMHead(nn.Module):
BRIDGETOWER_START_DOCSTRING,
)
class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
_tied_weights_keys = ["mlm_score.decoder.weight"]
def __init__(self, config):
super().__init__(config)
...
@@ -938,6 +938,7 @@ class CamembertForMaskedLM(CamembertPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -1433,6 +1434,7 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
...
@@ -601,6 +601,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
)
class CodeGenForCausalLM(CodeGenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.causal_mask"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
...
@@ -881,6 +881,7 @@ class ConvBertGeneratorPredictions(nn.Module):
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]
_tied_weights_keys = ["generator.lm_head.weight"]
def __init__(self, config):
super().__init__(config)
...
@@ -750,6 +750,7 @@ class CpmAntModel(CpmAntPreTrainedModel):
)
class CpmAntForCausalLM(CpmAntPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: CpmAntConfig):
super().__init__(config)
...
@@ -510,6 +510,7 @@ class CTRLModel(CTRLPreTrainedModel):
)
class CTRLLMHeadModel(CTRLPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
...
@@ -886,6 +886,7 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
@@ -1040,6 +1041,7 @@ class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
...