Unverified Commit 695928e1 authored by Sylvain Gugger, committed by GitHub

Tied params cleanup (#24211)

* First test

* Add info for all models

* style

* Repo consistency

* Fix last model and cleanup prints

* Repo consistency

* Use consistent function for detecting tied weights
parent 3723329d
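
The pattern applied to every model in the diff below is the same: a `_tied_weights_keys` class attribute now declares which state-dict keys are tied to another parameter, so tied weights can be detected in one consistent way. A minimal sketch of what the new attribute exposes, using a tiny random T5 configuration so nothing has to be downloaded (the hyperparameter values are illustrative only, not part of this diff):

```python
# Minimal sketch (assumes torch and transformers at this commit are installed).
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(vocab_size=100, d_model=32, d_kv=8, d_ff=64, num_layers=2, num_heads=2)
model = T5ForConditionalGeneration(config)

# The class attribute added by this commit lists the keys expected to be tied.
print(T5ForConditionalGeneration._tied_weights_keys)
# ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight']

# Sanity check: with tie_word_embeddings=True (the default), every declared key
# really does alias the shared embedding matrix.
state_dict = model.state_dict()
shared_ptr = model.shared.weight.data_ptr()
for key in model._tied_weights_keys:
    assert state_dict[key].data_ptr() == shared_ptr, f"{key} is not tied"
```

The new `test_tied_weights_keys` test at the bottom of the diff verifies exactly this property for every model class.
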
......@@ -757,6 +757,7 @@ class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel):
)
class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
......
......@@ -2331,6 +2331,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
_keys_to_ignore_on_save = [
r"speecht5.encoder.prenet.pos_sinusoidal_embed.weights",
]
_tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]
def __init__(self, config: SpeechT5Config):
super().__init__(config)
......
......@@ -648,6 +648,7 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
"cls.predictions.decoder.weight",
"embeddings.position_ids",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1338,6 +1338,7 @@ num_heads)`.
)
class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight", r"decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
......@@ -1510,6 +1511,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
r"decoder.embed_tokens.weight",
r"lm_head.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
......@@ -1823,6 +1825,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
)
class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight"]
def __init__(self, config: SwitchTransformersConfig):
super().__init__(config)
......
......@@ -1329,6 +1329,7 @@ class T5Model(T5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
......@@ -1533,6 +1534,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [
r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
......@@ -1840,6 +1842,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
)
class T5EncoderModel(T5PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight"]
def __init__(self, config: T5Config):
super().__init__(config)
......
......@@ -992,6 +992,7 @@ class TapasModel(TapasPreTrainedModel):
@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
class TapasForMaskedLM(TapasPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
config_class = TapasConfig
base_model_prefix = "tapas"
......
......@@ -1003,6 +1003,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]
_tied_weights_keys = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -789,6 +789,7 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
)
class TrOCRForCausalLM(TrOCRPreTrainedModel):
_keys_to_ignore_on_load_missing = ["output_projection.weight"]
_tied_weights_keys = ["output_projection.weight"]
def __init__(self, config):
config = copy.deepcopy(config)
......
......@@ -894,6 +894,7 @@ class ViltPooler(nn.Module):
)
class ViltForMaskedLM(ViltPreTrainedModel):
_keys_to_ignore_on_load_missing = ["mlm_score.decoder.bias"]
_tied_weights_keys = ["mlm_score.decoder.weight", "mlm_score.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -872,6 +872,7 @@ class VisualBertModel(VisualBertPreTrainedModel):
)
class VisualBertForPreTraining(VisualBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -1462,6 +1463,7 @@ class VisualBertRegionToPhraseAttention(nn.Module):
)
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1324,6 +1324,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
_keys_to_ignore_on_save = [
r"proj_out.weight",
]
_tied_weights_keys = ["proj_out.weight"]
def __init__(self, config: WhisperConfig):
super().__init__(config)
......
......@@ -757,6 +757,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
_keys_to_ignore_on_save = [
r"model.embed_positions.weights",
]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -671,6 +671,7 @@ class XLMPredLayer(nn.Module):
)
class XLMWithLMHeadModel(XLMPreTrainedModel):
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
_tied_weights_keys = ["pred_layer.proj.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1769,6 +1769,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
_tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
def __init__(self, config: XLMProphetNetConfig):
super().__init__(config)
......@@ -1903,6 +1904,7 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
"encoder.word_embeddings.weight",
"lm_head.weight",
]
_tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
def __init__(self, config: XLMProphetNetConfig):
super().__init__(config)
......@@ -2118,6 +2120,7 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: XLMProphetNetConfig):
# set config for CLM
......
......@@ -888,6 +888,7 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -1046,6 +1047,7 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -853,6 +853,7 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......@@ -1003,6 +1004,7 @@ class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -1293,6 +1293,7 @@ class XLNetModel(XLNetPreTrainedModel):
)
class XLNetLMHeadModel(XLNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_loss.weight"]
_tied_weights_keys = ["lm_loss.weight"]
def __init__(self, config):
super().__init__(config)
......
......@@ -992,6 +992,7 @@ class XmodForCausalLM(XmodPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod
def __init__(self, config):
......@@ -1154,6 +1155,7 @@ class XmodForMaskedLM(XmodPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
# Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with Roberta->Xmod
def __init__(self, config):
......
......@@ -854,6 +854,7 @@ class YosoForMaskedLM(YosoPreTrainedModel):
"cls.predictions.decoder.weight",
"embeddings.position_ids",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import gc
import glob
......@@ -22,6 +23,7 @@ import os
import os.path
import pickle
import random
import re
import sys
import tempfile
import unittest
......@@ -127,6 +129,7 @@ if is_torch_available():
T5ForConditionalGeneration,
)
from transformers.modeling_utils import shard_checkpoint
from transformers.pytorch_utils import id_tensor_storage
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
......@@ -1662,6 +1665,33 @@ class ModelTesterMixin:
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.tie_word_embeddings = True
for model_class in self.all_model_classes:
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Check that each declared key matches at least one shared parameter
for key in tied_weight_keys:
if not any(re.search(key, p) for group in tied_params for p in group):
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
# Remove the tied weight keys from the tied params -> each group should have at most one entry left afterwards
for key in tied_weight_keys:
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(tied_params, [])
def test_tied_model_weights_key_ignore(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
......
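
For reference, the new `test_tied_weights_keys` above leans on `id_tensor_storage` (the import added in this diff) to discover which parameters alias the same storage, then checks the declared `_tied_weights_keys` against those groups. A standalone sketch of that detection step, with a toy module that is purely illustrative and not part of the commit:

```python
# Rough, self-contained illustration of the detection used by test_tied_weights_keys.
# Only `id_tensor_storage` comes from the library; TinyLM is a hypothetical module.
import collections

from torch import nn
from transformers.pytorch_utils import id_tensor_storage


class TinyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the output projection to the embedding


model = TinyLM()

# Group every state_dict entry by the storage it points to: names that share
# a bucket are views of the same underlying tensor, i.e. tied parameters.
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[id_tensor_storage(tensor)].append(name)

tied_groups = [names for names in ptrs.values() if len(names) > 1]
print(tied_groups)  # [['embed.weight', 'lm_head.weight']]
```

Because the test compares keys with `re.search`, regex-style entries such as TransfoXL's `r"crit\.out_projs\.\d+"` also work: a single declared key can cover a whole family of tied parameters.
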