Unverified Commit d143087d authored by Nicolas Patry, committed by GitHub

Making sure we can use safetensors to serialize all the time. (#22437)



* Making sure we can use safetensors to serialize all the time.

* Expanding the tests for increased coverage.

* Update the test.

* Getting current state of affairs.

* Tentative fix.

* Fixing black version.

* Fixing the worst offenders.

* Try to modify less files.

* Fixing blip_2 (Weird solution right now).

* Fixing deta.

* Fix blip ?

* Missing extra newline.

* No deta modification.

* Adding some comments.

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Addressing comments.

* Addressing comments.

* creating warn_once.

* Warning_once !

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 516077b3
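
For context on the diff below: safetensors refuses to serialize two entries that alias the same storage, which is exactly what tied weights produce in a PyTorch state dict. A minimal sketch of the failure this commit works around (the tensor names and file name are illustrative):

import torch
from safetensors.torch import save_file

weight = torch.zeros(10, 4)
state_dict = {"embed.weight": weight, "lm_head.weight": weight}  # two names, one storage

try:
    save_file(state_dict, "model.safetensors")
except RuntimeError as err:
    # safetensors rejects aliased tensors rather than silently duplicating them on disk
    print(f"refused to serialize: {err}")
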
src/transformers/modeling_utils.py
@@ -1736,6 +1736,41 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             for ignore_key in self._keys_to_ignore_on_save:
                 if ignore_key in state_dict.keys():
                     del state_dict[ignore_key]
+        if safe_serialization:
+            # Safetensors does not allow tensor aliasing.
+            # We're going to remove aliases before saving.
+            ptrs = collections.defaultdict(list)
+            for name, tensor in state_dict.items():
+                ptrs[tensor.data_ptr()].append(name)
+
+            # These are all the pointers of shared tensors.
+            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            warn_names = set()
+            for names in shared_ptrs.values():
+                # Remove the keys that are declared as known duplicates on load.
+                # This makes sure the name that is kept is consistent.
+                if self._keys_to_ignore_on_load_missing is not None:
+                    for name in names:
+                        matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
+                        if matches_pattern and name in state_dict:
+                            del state_dict[name]
+
+                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+                # If the link between tensors was done at runtime then `from_pretrained` will not get
+                # the key back, leading to a randomly initialized tensor. A proper warning will be shown
+                # during reload (if applicable), but since the file is not necessarily compatible with
+                # the config, better show a proper warning here as well.
+                found = 0
+                for name in names:
+                    if name in state_dict:
+                        found += 1
+                        if found > 1:
+                            del state_dict[name]
+                            warn_names.add(name)
+            if len(warn_names) > 0:
+                logger.warning_once(
+                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading.",
+                )

         # Shard the model if it is too big.
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
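
The deduplication above hinges on Tensor.data_ptr(): tied parameters share storage, so grouping state dict keys by pointer exposes every alias group. A self-contained sketch of that detection step, using a toy module rather than a real transformers model:

import collections

import torch
from torch import nn


class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # runtime tie: two keys, one storage


ptrs = collections.defaultdict(list)
for name, tensor in TiedToy().state_dict().items():
    ptrs[tensor.data_ptr()].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # one entry mapping a pointer to ['embed.weight', 'lm_head.weight']
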
@@ -2813,6 +2848,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+        # Some tensors may already have been filled by another key (tied weights).
+        existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
+        missing_keys = [
+            k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
+        ]

         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_missing is not None:
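
Why the pointer filter matters on load: after the save-time deduplication, a checkpoint stores only one copy of each tied group, so the sibling names look missing even though their storage was filled through the kept key. A hedged sketch of the same check in isolation:

import torch
from torch import nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tied before loading

model_state_dict = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}
loaded_keys = ["embed.weight"]     # the safetensors file kept a single copy
missing_keys = ["lm_head.weight"]  # naively reported as missing

existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys}
missing_keys = [k for k in missing_keys if model_state_dict[k].data_ptr() not in existing_ptrs]
print(missing_keys)  # [] -- the tied head was already filled through embed.weight
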
src/transformers/models/blip_2/modeling_blip_2.py
@@ -1238,8 +1238,28 @@ class Blip2Model(Blip2PreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def get_input_embeddings(self) -> nn.Module:
-        return self.vision_model.embeddings.patch_embedding
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.language_model.get_output_embeddings()
+
+    def get_encoder(self):
+        return self.language_model.get_encoder()
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def _tie_weights(self):
+        if not self.config.use_decoder_only_language_model:
+            self.language_model.encoder.embed_tokens = self.language_model.shared
+            self.language_model.decoder.embed_tokens = self.language_model.shared

     @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
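
The new _tie_weights override mirrors the T5-style sharing inside the BLIP-2 language model: both token-embedding attributes are rebound to one shared module, so the tie is re-established after loading instead of relying on runtime aliasing that a checkpoint cannot express. A toy version of the same pattern (class and attribute names are illustrative):

from torch import nn


class Seq2SeqToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(32, 8)
        self.encoder_embed_tokens = nn.Embedding(32, 8)
        self.decoder_embed_tokens = nn.Embedding(32, 8)

    def _tie_weights(self):
        # Same move as Blip2Model: both sides point at the shared table.
        self.encoder_embed_tokens = self.shared
        self.decoder_embed_tokens = self.shared


toy = Seq2SeqToy()
toy._tie_weights()
assert toy.encoder_embed_tokens.weight.data_ptr() == toy.shared.weight.data_ptr()
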
src/transformers/models/deta/modeling_deta.py
@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):
 def _get_clones(module, N):
-    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    return nn.ModuleList([module for i in range(N)])


 def inverse_sigmoid(x, eps=1e-5):
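
The DETA change is the difference between cloning and aliasing: copy.deepcopy gives each layer independent storage, while repeating the same module N times yields N state dict names for one weight, a sharing pattern the new save-time deduplication can handle. A quick comparison:

import copy

from torch import nn

layer = nn.Linear(4, 4)
cloned = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])  # three independent weights
shared = nn.ModuleList([layer for _ in range(3)])                 # three names for one weight

print(len({m.weight.data_ptr() for m in cloned}))  # 3
print(len({m.weight.data_ptr() for m in shared}))  # 1
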
src/transformers/models/llama/modeling_llama.py
@@ -609,8 +609,6 @@ class LlamaModel(LlamaPreTrainedModel):
 class LlamaForCausalLM(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
src/transformers/models/pix2struct/configuration_pix2struct.py
@@ -357,9 +357,10 @@ class Pix2StructConfig(PretrainedConfig):
         initializer_factor=1.0,
         initializer_range=0.02,
         is_vqa=False,
+        tie_word_embeddings=False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

         if text_config is None:
             text_config = {}
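
The Pix2Struct change makes the no-tying default explicit by forwarding it to the base config, which is what the generic tying logic consults. A stripped-down sketch of the pattern, with hypothetical class names standing in for PretrainedConfig and Pix2StructConfig:

class BaseConfig:
    def __init__(self, tie_word_embeddings=True, **kwargs):
        # The generic tie_weights() machinery reads this flag.
        self.tie_word_embeddings = tie_word_embeddings


class Pix2StructLikeConfig(BaseConfig):
    def __init__(self, tie_word_embeddings=False, **kwargs):
        # Forward the model-specific default instead of silently inheriting True.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


assert Pix2StructLikeConfig().tie_word_embeddings is False
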
tests/test_modeling_common.py
@@ -27,6 +27,7 @@ import tempfile
 import unittest
 import unittest.mock as mock
 import warnings
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -1626,6 +1627,41 @@ class ModelTesterMixin:
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    @require_safetensors
+    def test_can_use_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model_tied = model_class(config)
+            with tempfile.TemporaryDirectory() as d:
+                try:
+                    model_tied.save_pretrained(d, safe_serialization=True)
+                except Exception as e:
+                    raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
+
+                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
+                # Check that the state dicts are identical.
+                reloaded_state = model_reloaded.state_dict()
+                for k, v in model_tied.state_dict().items():
+                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
+                    torch.testing.assert_close(
+                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
+                    )
+
+                # Check that the tensor sharing is correct.
+                ptrs = defaultdict(list)
+                for k, v in model_tied.state_dict().items():
+                    ptrs[v.data_ptr()].append(k)
+
+                shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
+
+                for _, shared_names in shared_ptrs.items():
+                    reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
+                    self.assertEqual(
+                        len(reloaded_ptrs),
+                        1,
+                        f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
+                    )
+
     def test_tied_model_weights_key_ignore(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
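
Outside the test suite, the same guarantee can be exercised by hand: save any model with safe_serialization=True and confirm the reload reports nothing missing or unexpected. A usage sketch (the checkpoint name is only an example):

import tempfile

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
with tempfile.TemporaryDirectory() as d:
    model.save_pretrained(d, safe_serialization=True)
    reloaded, infos = AutoModel.from_pretrained(d, output_loading_info=True)
    assert not infos["missing_keys"] and not infos["unexpected_keys"]
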