Unverified commit 695928e1, authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
Parent: 3723329d
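The attribute this commit introduces across models, _tied_weights_keys, lists the state-dict keys whose tensors share storage with another parameter (for example an LM head tied to the input embeddings). Below is a minimal sketch, not part of the commit, of how such sharing can be inspected with the id_tensor_storage helper that the new test relies on; the exact group of names printed is an assumption that depends on the model and its tie_word_embeddings setting:

import collections

from transformers import T5Config, T5ForConditionalGeneration
from transformers.pytorch_utils import id_tensor_storage

# Randomly initialized model; T5Config ties word embeddings by default.
model = T5ForConditionalGeneration(T5Config())

# Group state-dict keys by underlying tensor storage: keys that land in the same
# group point at the same memory, i.e. they are tied weights.
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[id_tensor_storage(tensor)].append(name)

tied_groups = [names for names in ptrs.values() if len(names) > 1]
print(tied_groups)
# Expected to contain a group along the lines of
# ["shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"],
# i.e. the entries declared in T5ForConditionalGeneration._tied_weights_keys plus the
# canonical parameter ("shared.weight") that the others alias.

The per-model lists added in this diff record exactly these aliased keys, and the new test at the bottom checks that each declared key really does point at shared storage.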
@@ -757,6 +757,7 @@ class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel):
 )
 class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         config = copy.deepcopy(config)
@@ -2331,6 +2331,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
     _keys_to_ignore_on_save = [
         r"speecht5.encoder.prenet.pos_sinusoidal_embed.weights",
     ]
+    _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]

     def __init__(self, config: SpeechT5Config):
         super().__init__(config)
@@ -648,6 +648,7 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
         "cls.predictions.decoder.weight",
         "embeddings.position_ids",
     ]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1338,6 +1338,7 @@ num_heads)`.
 )
 class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight", r"decoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)
@@ -1510,6 +1511,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         r"decoder.embed_tokens.weight",
         r"lm_head.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)
@@ -1823,6 +1825,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
 )
 class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight"]

     def __init__(self, config: SwitchTransformersConfig):
         super().__init__(config)
@@ -1329,6 +1329,7 @@ class T5Model(T5PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [
         r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

     def __init__(self, config: T5Config):
         super().__init__(config)
@@ -1533,6 +1534,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     _keys_to_ignore_on_load_unexpected = [
         r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
     ]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

     def __init__(self, config: T5Config):
         super().__init__(config)
@@ -1840,6 +1842,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
 )
 class T5EncoderModel(T5PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight"]

     def __init__(self, config: T5Config):
         super().__init__(config)
@@ -992,6 +992,7 @@ class TapasModel(TapasPreTrainedModel):
 @add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
 class TapasForMaskedLM(TapasPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
     config_class = TapasConfig
     base_model_prefix = "tapas"
@@ -1003,6 +1003,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
 )
 class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]
+    _tied_weights_keys = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]

     def __init__(self, config):
         super().__init__(config)
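Unlike most models in this diff, which list exact parameter names, Transformer-XL above declares its tied weights as regex patterns, because its adaptive softmax can tie a variable number of output projections and layers. The new test further down matches keys with re.search, so both exact names and patterns work. A small standalone illustration, using hypothetical parameter names, of how the patterns match:

import re

# Patterns copied from TransfoXLLMHeadModel._tied_weights_keys above.
patterns = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]

# Hypothetical state-dict keys, for illustration only.
param_names = ["crit.out_projs.0", "crit.out_layers.3.weight", "crit.cluster_weight"]

matched = [name for name in param_names if any(re.search(p, name) for p in patterns)]
print(matched)  # ['crit.out_projs.0', 'crit.out_layers.3.weight']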
@@ -789,6 +789,7 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
 )
 class TrOCRForCausalLM(TrOCRPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["output_projection.weight"]
+    _tied_weights_keys = ["output_projection.weight"]

     def __init__(self, config):
         config = copy.deepcopy(config)
@@ -894,6 +894,7 @@ class ViltPooler(nn.Module):
 )
 class ViltForMaskedLM(ViltPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["mlm_score.decoder.bias"]
+    _tied_weights_keys = ["mlm_score.decoder.weight", "mlm_score.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -872,6 +872,7 @@ class VisualBertModel(VisualBertPreTrainedModel):
 )
 class VisualBertForPreTraining(VisualBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1462,6 +1463,7 @@ class VisualBertRegionToPhraseAttention(nn.Module):
 )
 class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1324,6 +1324,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     _keys_to_ignore_on_save = [
         r"proj_out.weight",
     ]
+    _tied_weights_keys = ["proj_out.weight"]

     def __init__(self, config: WhisperConfig):
         super().__init__(config)
@@ -757,6 +757,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
     _keys_to_ignore_on_save = [
         r"model.embed_positions.weights",
     ]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -671,6 +671,7 @@ class XLMPredLayer(nn.Module):
 )
 class XLMWithLMHeadModel(XLMPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
+    _tied_weights_keys = ["pred_layer.proj.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1769,6 +1769,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
 # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
 class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]

     def __init__(self, config: XLMProphetNetConfig):
         super().__init__(config)
@@ -1903,6 +1904,7 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
         "encoder.word_embeddings.weight",
         "lm_head.weight",
     ]
+    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]

     def __init__(self, config: XLMProphetNetConfig):
         super().__init__(config)
@@ -2118,6 +2120,7 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
 # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
 class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config: XLMProphetNetConfig):
         # set config for CLM
@@ -888,6 +888,7 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1046,6 +1047,7 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -853,6 +853,7 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1003,6 +1004,7 @@ class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1293,6 +1293,7 @@ class XLNetModel(XLNetPreTrainedModel):
 )
 class XLNetLMHeadModel(XLNetPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"lm_loss.weight"]
+    _tied_weights_keys = ["lm_loss.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -992,6 +992,7 @@ class XmodForCausalLM(XmodPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod
     def __init__(self, config):
@@ -1154,6 +1155,7 @@ class XmodForMaskedLM(XmodPreTrainedModel):
     _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

     # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with Roberta->Xmod
     def __init__(self, config):
@@ -854,6 +854,7 @@ class YosoForMaskedLM(YosoPreTrainedModel):
         "cls.predictions.decoder.weight",
         "embeddings.position_ids",
     ]
+    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import collections
 import copy
 import gc
 import glob
@@ -22,6 +23,7 @@ import os
 import os.path
 import pickle
 import random
+import re
 import sys
 import tempfile
 import unittest
@@ -127,6 +129,7 @@ if is_torch_available():
         T5ForConditionalGeneration,
     )
     from transformers.modeling_utils import shard_checkpoint
+    from transformers.pytorch_utils import id_tensor_storage

     # Fake pretrained models for tests
     class BaseModel(PreTrainedModel):
@@ -1662,6 +1665,33 @@ class ModelTesterMixin:
             f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
         )

+    def test_tied_weights_keys(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config.tie_word_embeddings = True
+        for model_class in self.all_model_classes:
+            model_tied = model_class(config)
+
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model_tied.state_dict().items():
+                ptrs[id_tensor_storage(tensor)].append(name)
+
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+
+            tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+            # Check that each declared key matches at least one shared parameter.
+            for key in tied_weight_keys:
+                if not any(re.search(key, p) for group in tied_params for p in group):
+                    raise ValueError(f"{key} is not a tied weight key for {model_class}.")
+
+            # Remove the declared tied weights from the shared groups -> at most one name (the canonical parameter) should be left in each group.
+            for key in tied_weight_keys:
+                for i in range(len(tied_params)):
+                    tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
+
+            tied_params = [group for group in tied_params if len(group) > 1]
+            self.assertListEqual(tied_params, [])
+
     def test_tied_model_weights_key_ignore(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
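To make the elimination step at the end of test_tied_weights_keys concrete, here is the same logic run by hand on one assumed shared group (the T5 tying sketched earlier): once every name matching a declared key is removed, at most one name (the canonical parameter) may remain, so no group of length greater than one survives and the final assertion passes.

import re

# Assumed shared-storage group and the keys T5ForConditionalGeneration declares in this diff.
tied_params = [["shared.weight", "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]]
tied_weight_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

# Same elimination loop as in the test above.
for key in tied_weight_keys:
    for i in range(len(tied_params)):
        tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]

print(tied_params)  # [['shared.weight']]
print([group for group in tied_params if len(group) > 1])  # [] -> the assertListEqual check passes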