"INSTALL/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "99fbd14f08c339e80ccc1fb1d61f6da555671510"
Unverified Commit 443bf5e9 authored by Lysandre Debut, committed by GitHub

Fix safetensors failing tests (#27231)



* Fix Kosmos2

* Fix ProphetNet

* Fix MarianMT

* Fix M4T

* XLM ProphetNet

* ProphetNet fix

* XLM ProphetNet

* Final M4T fixes

* Tied weights keys

* Revert M4T changes

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 4557a0de
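Background for the changes below: safetensors does not serialize tensors that share storage, so tied weights are stored only once on disk and each model must be able to re-tie them after loading. That is what the added `_tie_weights` hooks and the expanded `_tied_weights_keys` lists provide. A minimal sketch of the underlying constraint, assuming a recent safetensors release that raises on shared tensors (the file name and error handling here are illustrative, not part of the PR):

import os
import tempfile

import torch
from safetensors.torch import save_file

embed = torch.nn.Embedding(10, 4)
lm_head = torch.nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie the two modules: a single shared Parameter

state_dict = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}
with tempfile.TemporaryDirectory() as tmp:
    try:
        # recent safetensors releases are expected to refuse tensors that share memory
        save_file(state_dict, os.path.join(tmp, "tied.safetensors"))
    except RuntimeError as err:
        # the error asks for duplicates to be dropped, which is what
        # save_pretrained does for the names listed in _tied_weights_keys
        print(err)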
@@ -1755,6 +1755,11 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
         self.encoder.word_embeddings = self.word_embeddings
         self.decoder.word_embeddings = self.word_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
+            self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)
+
     def get_encoder(self):
         return self.encoder
@@ -1876,6 +1881,10 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)
+
     def get_input_embeddings(self):
         return self.prophetnet.word_embeddings
@@ -2070,7 +2079,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
     PROPHETNET_START_DOCSTRING,
 )
 class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = [
+        "prophetnet.word_embeddings.weight",
+        "prophetnet.decoder.word_embeddings.weight",
+        "lm_head.weight",
+    ]

     def __init__(self, config: ProphetNetConfig):
         # set config for CLM
@@ -2100,6 +2113,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)
+
     def set_decoder(self, decoder):
         self.prophetnet.decoder = decoder
@@ -2311,7 +2328,15 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
     def __init__(self, config: ProphetNetConfig):
         super().__init__(config)
-        self.decoder = ProphetNetDecoder(config)
+
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _tie_weights(self):
+        self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
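As an illustration of what the hooks above guarantee, a hedged sketch follows; the tiny `ProphetNetConfig` arguments are made-up illustrative values, not anything from this PR. After construction, the LM head and the decoder word embeddings share one parameter, which is exactly the sharing that `_tied_weights_keys` declares and `_tie_weights` re-creates after a safetensors load.

from transformers import ProphetNetConfig, ProphetNetForCausalLM

# tiny illustrative config; the values (and choice of arguments) are assumptions
config = ProphetNetConfig(
    vocab_size=99,
    hidden_size=16,
    num_encoder_layers=1,
    num_decoder_layers=1,
    num_encoder_attention_heads=2,
    num_decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = ProphetNetForCausalLM(config)

# with tie_word_embeddings left at its default (True), tie_weights() makes the
# LM head and the decoder word embeddings point at the same storage
assert (
    model.lm_head.weight.data_ptr()
    == model.prophetnet.decoder.word_embeddings.weight.data_ptr()
)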
@@ -1779,6 +1779,11 @@ class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
         self.encoder.word_embeddings = self.word_embeddings
         self.decoder.word_embeddings = self.word_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
+            self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)
+
     def get_encoder(self):
         return self.encoder
@@ -1901,6 +1906,10 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)
+
     def get_input_embeddings(self):
         return self.prophetnet.word_embeddings
@@ -2098,7 +2107,11 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
 )
 # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
 class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = [
+        "prophetnet.word_embeddings.weight",
+        "prophetnet.decoder.word_embeddings.weight",
+        "lm_head.weight",
+    ]

     def __init__(self, config: XLMProphetNetConfig):
         # set config for CLM
@@ -2128,6 +2141,10 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)
+
     def set_decoder(self, decoder):
         self.prophetnet.decoder = decoder
@@ -2340,7 +2357,15 @@ class XLMProphetNetDecoderWrapper(XLMProphetNetPreTrainedModel):
     def __init__(self, config: XLMProphetNetConfig):
         super().__init__(config)
-        self.decoder = XLMProphetNetDecoder(config)
+
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.decoder = XLMProphetNetDecoder(config, word_embeddings=self.word_embeddings)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _tie_weights(self):
+        self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())

     def forward(self, *args, **kwargs):
         return self.decoder(*args, **kwargs)
@@ -304,6 +304,25 @@ class Kosmos2ModelTest(ModelTesterMixin, unittest.TestCase):
            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

+    def test_load_save_without_tied_weights(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        config.text_config.tie_word_embeddings = False
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            with tempfile.TemporaryDirectory() as d:
+                model.save_pretrained(d)
+                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
+                # Checking the state dicts are correct
+                reloaded_state = model_reloaded.state_dict()
+                for k, v in model.state_dict().items():
+                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
+                    torch.testing.assert_close(
+                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
+                    )
+                # Checking there was no complain of missing weights
+                self.assertEqual(infos["missing_keys"], [])
+
     # overwrite from common in order to use `self.model_tester.text_model_tester.num_hidden_layers`
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
...
@@ -76,7 +76,7 @@ from transformers.testing_utils import (
 from transformers.utils import (
     CONFIG_NAME,
     GENERATION_CONFIG_NAME,
-    WEIGHTS_NAME,
+    SAFE_WEIGHTS_NAME,
     is_accelerate_available,
     is_flax_available,
     is_tf_available,
@@ -91,6 +91,7 @@ if is_accelerate_available():
 if is_torch_available():
     import torch

+    from safetensors.torch import load_file as safe_load_file
     from safetensors.torch import save_file as safe_save_file
     from torch import nn
@@ -311,17 +312,20 @@ class ModelTesterMixin:
             # check that certain keys didn't get saved with the model
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
-                output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
-                state_dict_saved = torch.load(output_model_file)
+                output_model_file = os.path.join(tmpdirname, SAFE_WEIGHTS_NAME)
+                state_dict_saved = safe_load_file(output_model_file)
+
                 for k in _keys_to_ignore_on_save:
                     self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys()))

                 # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
                 load_result = model.load_state_dict(state_dict_saved, strict=False)
-                self.assertTrue(
-                    len(load_result.missing_keys) == 0
-                    or set(load_result.missing_keys) == set(model._keys_to_ignore_on_save)
-                )
+                keys_to_ignore = set(model._keys_to_ignore_on_save)
+
+                if hasattr(model, "_tied_weights_keys"):
+                    keys_to_ignore.update(set(model._tied_weights_keys))
+
+                self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore)
                 self.assertTrue(len(load_result.unexpected_keys) == 0)

     def test_gradient_checkpointing_backward_compatibility(self):
...
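For reference, a minimal sketch of the safetensors round trip the updated common test exercises; the tiny checkpoint name is only an assumption for illustration, and `SAFE_WEIGHTS_NAME` / `safe_load_file` are the same helpers imported above.

import os
import tempfile

from safetensors.torch import load_file as safe_load_file
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_NAME

# hypothetical tiny checkpoint used purely to keep the sketch fast
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

with tempfile.TemporaryDirectory() as tmpdir:
    model.save_pretrained(tmpdir)  # safetensors is the default serialization here
    state_dict = safe_load_file(os.path.join(tmpdir, SAFE_WEIGHTS_NAME))
    load_result = model.load_state_dict(state_dict, strict=False)
    # tied weights are not duplicated on disk, so they may legitimately show up
    # in missing_keys; anything in unexpected_keys would be a real bug
    print(load_result.missing_keys, load_result.unexpected_keys)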