Unverified commit 47500b1d authored by Matt, committed by GitHub

Fix TF loading PT safetensors when weights are tied (#27490)



* Un-skip tests

* Add aliasing support to tf_to_pt_weight_rename

* Refactor tf-to-pt weight rename for simplicity

* Patch mobilebert

* Let us pray that the transfo-xl one works

* Add XGLM rename

* Expand the test to see if we can get more models to break

* Expand the test to see if we can get more models to break

* Fix MPNet (it was actually an unrelated bug)

* Fix MPNet (it was actually an unrelated bug)

* Add speech2text fix

* Update src/transformers/modeling_tf_pytorch_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/mobilebert/modeling_tf_mobilebert.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update to always return a tuple from tf_to_pt_weight_rename

* reformat

* Add a couple of missing tuples

* Remove the extra test for tie_word_embeddings since it didn't cause any unexpected failures anyway

* Revert changes to modeling_tf_mpnet.py

* Skip MPNet test and add explanation

* Add weight link for BART

* Add TODO to clean this up a bit

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9f1f11a2
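
Note on the change: the tf_to_pt_weight_rename hook now always returns a tuple of candidate PyTorch key names (aliases) in priority order, rather than a single string, so a TF weight can be matched against whichever of its tied names actually survived safetensors serialization. A minimal sketch of the new contract, mirroring the tied LM-head case handled for several models below (the weight names are illustrative):

# Sketch of the tuple-returning rename contract (illustrative weight names).
def tf_to_pt_weight_rename(tf_weight):
    # When weights are tied, safetensors may keep only one of the aliased keys,
    # so we return every PyTorch name the weight could live under, best first.
    if tf_weight == "lm_head.weight":
        return (tf_weight, "model.embed_tokens.weight")
    return (tf_weight,)
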
@@ -318,7 +318,14 @@ def load_pytorch_state_dict_in_tf2_model(
             name_scope=_prefix,
         )
         if tf_to_pt_weight_rename is not None:
-            name = tf_to_pt_weight_rename(name)
+            aliases = tf_to_pt_weight_rename(name)  # Is a tuple to account for possible name aliasing
+            for alias in aliases:  # The aliases are in priority order, take the first one that matches
+                if alias in tf_keys_to_pt_keys:
+                    name = alias
+                    break
+            else:
+                # If none of the aliases match, just use the first one (it'll be reported as missing)
+                name = aliases[0]

         # Find associated numpy array in pytorch model state dict
         if name not in tf_keys_to_pt_keys:
...
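
For clarity, the block above relies on Python's for/else: the else branch runs only when the loop finishes without hitting break. A standalone illustration with hypothetical dictionary contents:

# Hypothetical state dict keys: the tied alias is present, the TF-derived name is not.
tf_keys_to_pt_keys = {"model.decoder.embed_tokens.weight": None}

aliases = ("model.shared.weight", "model.decoder.embed_tokens.weight")
for alias in aliases:  # take the first alias that exists in the PT state dict
    if alias in tf_keys_to_pt_keys:
        name = alias
        break
else:
    name = aliases[0]  # nothing matched; the primary name will be reported as missing

print(name)  # model.decoder.embed_tokens.weight
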
@@ -2892,6 +2892,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # Instantiate model.
         model = cls(config, *model_args, **model_kwargs)

+        if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
+            # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method
+            # to be defined for each class that requires a rename. We can probably just have a class-level
+            # dict and a single top-level method or something and cut down a lot of boilerplate code
+            tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
+
         if from_pt:
             from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
...
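
The hunk above lets from_pretrained pick up a model-defined rename hook automatically. A simplified sketch of that discovery logic under assumed names (resolve_rename_hook and DummyModel are illustrative, not part of the library):

class DummyModel:
    def tf_to_pt_weight_rename(self, tf_weight):
        return (tf_weight,)

def resolve_rename_hook(model, tf_to_pt_weight_rename=None):
    # An explicitly passed hook still takes precedence; otherwise fall back
    # to the method defined on the model class, if any.
    if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
        tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
    return tf_to_pt_weight_rename

hook = resolve_rename_hook(DummyModel())
print(hook("lm_head.weight"))  # ('lm_head.weight',)
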
@@ -494,6 +494,12 @@ class TFBartPretrainedModel(TFPreTrainedModel):
             dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"] * 2
         return dummy_inputs

+    def tf_to_pt_weight_rename(self, tf_weight):
+        if tf_weight == "model.shared.weight":
+            return tf_weight, "model.decoder.embed_tokens.weight"
+        else:
+            return (tf_weight,)
+

 BART_START_DOCSTRING = r"""
     This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
...
@@ -987,6 +987,21 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
         return inputs

+    # Adapted from the torch tie_weights function
+    def tf_to_pt_weight_rename(self, tf_weight):
+        if self.config.tie_word_embeddings and "crit.out_layers" in tf_weight:
+            return tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers")
+        elif self.config.tie_projs and "crit.out_projs" in tf_weight:
+            for i, tie_proj in enumerate(self.config.tie_projs):
+                if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
+                    # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
+                    return tf_weight, tf_weight.replace(f"crit.out_projs.{i}", "transformer.word_emb.emb_projs.0")
+                elif tie_proj and self.config.div_val != 1:
+                    # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
+                    return tf_weight, tf_weight.replace("crit.out_projs", "transformer.word_emb.emb_projs")
+        else:
+            return (tf_weight,)
+

 @add_start_docstrings(
     """
...
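
As a worked example of the tied-embedding branch above (assuming a config with tie_word_embeddings set to True; the layer index is illustrative):

tf_weight = "crit.out_layers.0.weight"
aliases = (tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers"))
print(aliases)
# ('crit.out_layers.0.weight', 'transformer.word_emb.emb_layers.0.weight')
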
@@ -291,16 +291,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, new_embeddings):
         return self.decoder.set_output_embeddings(new_embeddings)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Example:
-
-        ```python
-        >>> from transformers import TFEncoderDecoderModel
-
-        >>> model = TFEncoderDecoderModel.from_pretrained("ydshieh/bert2bert-cnn_dailymail-fp16")
-        ```"""
+    def tf_to_pt_weight_rename(self, tf_weight):
         # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
         # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
         # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
@@ -311,18 +302,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
         # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
         # or not.
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-        encoder_model_type = config.encoder.model_type
-
-        def tf_to_pt_weight_rename(tf_weight):
-            if "encoder" in tf_weight and "decoder" not in tf_weight:
-                return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
-            else:
-                return tf_weight
-
-        kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
-        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        encoder_model_type = self.config.encoder.model_type
+        if "encoder" in tf_weight and "decoder" not in tf_weight:
+            return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
+        else:
+            return (tf_weight,)

     @classmethod
     def from_encoder_decoder_pretrained(
...
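
To illustrate the regex rename above, here is the substitution applied to a hypothetical TF weight name for a BERT encoder (the name is an example, not taken from this diff):

import re

encoder_model_type = "bert"
tf_weight = "encoder.bert.embeddings.word_embeddings.weight"
print(re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight))
# encoder.embeddings.word_embeddings.weight
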
@@ -1088,6 +1088,12 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTra
             attentions=outputs.attentions,
         )

+    def tf_to_pt_weight_rename(self, tf_weight):
+        if tf_weight == "cls.predictions.decoder.weight":
+            return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
+        else:
+            return (tf_weight,)
+

 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING)
 class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
@@ -1168,6 +1174,12 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
             attentions=outputs.attentions,
         )

+    def tf_to_pt_weight_rename(self, tf_weight):
+        if tf_weight == "cls.predictions.decoder.weight":
+            return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
+        else:
+            return (tf_weight,)
+

 class TFMobileBertOnlyNSPHead(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
...
@@ -1460,3 +1460,9 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
+
+    def tf_to_pt_weight_rename(self, tf_weight):
+        if tf_weight == "lm_head.weight":
+            return tf_weight, "model.decoder.embed_tokens.weight"
+        else:
+            return (tf_weight,)
@@ -290,33 +290,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
     def set_output_embeddings(self, new_embeddings):
         return self.decoder.set_output_embeddings(new_embeddings)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r"""
-        Example:
-
-        ```python
-        >>> from transformers import TFVisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer
-        >>> from PIL import Image
-        >>> import requests
-
-        >>> image_processor = AutoImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en")
-        >>> decoder_tokenizer = AutoTokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
-        >>> model = TFVisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> img = Image.open(requests.get(url, stream=True).raw)
-        >>> pixel_values = image_processor(images=img, return_tensors="tf").pixel_values  # Batch size 1
-
-        >>> output_ids = model.generate(
-        ...     pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
-        ... ).sequences
-
-        >>> preds = decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        >>> preds = [pred.strip() for pred in preds]
-
-        >>> assert preds == ["a cat laying on top of a couch next to another cat"]
-        ```"""
+    def tf_to_pt_weight_rename(self, tf_weight):
         # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
         # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
         # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption
@@ -327,18 +301,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
         # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file.
         # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it
         # or not.
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-        encoder_model_type = config.encoder.model_type
-
-        def tf_to_pt_weight_rename(tf_weight):
-            if "encoder" in tf_weight and "decoder" not in tf_weight:
-                return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
-            else:
-                return tf_weight
-
-        kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
-        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        encoder_model_type = self.config.encoder.model_type
+        if "encoder" in tf_weight and "decoder" not in tf_weight:
+            return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),)
+        else:
+            return (tf_weight,)

     @classmethod
     def from_encoder_decoder_pretrained(
...
@@ -227,15 +227,10 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
         self.logit_scale = self.add_weight(shape=(1,), initializer=initializer, name="logit_scale")
         super().build(input_shape)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    def tf_to_pt_weight_rename(self, tf_weight):
         # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models
         # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal.
         # However, the name of that extra layer is the name of the MainLayer in the base model.
-        if kwargs.get("from_pt", False):
-
-            def tf_to_pt_weight_rename(tf_weight):
         if "vision_model" in tf_weight:
             if tf_weight.count("vision_model") == 1:
                 return re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight)
@@ -249,10 +244,7 @@ class TFVisionTextDualEncoderModel(TFPreTrainedModel):
         elif "text_model" in tf_weight:
             return re.sub(r"text_model\..*?\.", "text_model.", tf_weight)
         else:
-            return tf_weight
-
-        kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename
-        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            return (tf_weight,)

     @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
...
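
Similarly, the non-greedy vision_model\..*?\. pattern above strips the extra MainLayer segment from the vision tower. A hypothetical example with a ViT vision model (the weight name is illustrative):

import re

tf_weight = "vision_model.vit.embeddings.patch_embeddings.projection.weight"
print(re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight))
# vision_model.embeddings.patch_embeddings.projection.weight
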
@@ -924,3 +924,9 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
         )
+
+    def tf_to_pt_weight_rename(self, tf_weight):
+        if tf_weight == "lm_head.weight":
+            return tf_weight, "model.embed_tokens.weight"
+        else:
+            return (tf_weight,)
@@ -302,10 +302,6 @@ class MobileBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     def test_resize_tokens_embeddings(self):
         super().test_resize_tokens_embeddings()

-    @unittest.skip("This test is currently broken because of safetensors.")
-    def test_tf_from_pt_safetensors(self):
-        pass
-
     def setUp(self):
         self.model_tester = MobileBertModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MobileBertConfig, hidden_size=37)
...
@@ -246,7 +246,7 @@ class MPNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_mpnet_for_question_answering(*config_and_inputs)

-    @unittest.skip("This isn't passing but should, seems like a misconfiguration of tied weights.")
+    @unittest.skip("TFMPNet adds poolers to all models, unlike the PT model class.")
     def test_tf_from_pt_safetensors(self):
         return
...
@@ -196,10 +196,6 @@ class Speech2Text2StandaloneDecoderModelTest(
     def test_inputs_embeds(self):
         pass

-    @unittest.skip("This test is currently broken because of safetensors.")
-    def test_tf_from_pt_safetensors(self):
-        pass
-
     # speech2text2 has no base model
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -357,10 +357,6 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_model_parallelism(self):
         super().test_model_parallelism()

-    @unittest.skip("This test is currently broken because of safetensors.")
-    def test_tf_from_pt_safetensors(self):
-        pass
-

 @require_torch
 class XGLMModelLanguageGenerationTest(unittest.TestCase):
...