Unverified Commit 9b0a8ea7 authored by Nicolas Patry, committed by GitHub

Hard error when ignoring tensors. (#27484) (#29906)



* Hard error when ignoring tensors. (#27484)

* [WIP] Hard error when ignoring tensors.

* Better selection/error when saving a checkpoint.

- Find all names we should normally drop (those are declared in the
  transformers config)
- Find all disjoint tensors (for those we can safely trigger a copy to
  get rid of the sharing before saving)
- Clone those disjoint tensors, getting rid of the issue
- Find all identical names (those should be declared in the config,
  but we try to find them all anyway)
- For all identical names:
  - If they are in the config, just ignore them; everything is fine
  - If they are not, warn about them
- For all remaining tensors which are shared yet neither identical nor
  disjoint, raise a hard error (see the sketch after this list).
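
As a rough illustration of the classification above — a minimal sketch with made-up names (`classify`, `_end`), not the helpers this commit adds — shared tensor names can be bucketed by comparing the byte ranges they cover in their underlying storage:

```python
# Illustrative only: groups a set of state_dict names that share storage into
# "identical" (same bytes), "disjoint" (non-overlapping slices, safe to clone),
# or "shared" (overlapping but not identical, the case that now hard-errors).
from typing import Dict, Set

import torch


def _end(t: torch.Tensor) -> int:
    # One past the last byte the tensor touches in its underlying storage.
    return t.view(-1)[-1].data_ptr() + t.element_size() if t.nelement() else t.data_ptr()


def classify(names: Set[str], state_dict: Dict[str, torch.Tensor]) -> str:
    ranges = {(state_dict[n].data_ptr(), _end(state_dict[n])) for n in names}
    if len(ranges) == 1:
        return "identical"  # drop duplicates if declared in the config, else complain
    spans = sorted(ranges)
    overlapping = any(start < prev_end for (_, prev_end), (start, _) in zip(spans, spans[1:]))
    return "shared" if overlapping else "disjoint"
```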

* Adding a failing test on `main` that passes here.

* We don't need to keep the subfolder logic in this test.

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add small tests.

* Dead variable.

* Fixup.

* Fixing `tied_weights_keys` on generic models.

* Fixup + T5 encoder/decoder tying (with different layers)

* Code quality.

* Dynamic member.

* trigger

* Fixing encoder name for other types of encoder/decoder combos.

* Fix scoping.

* Update .github/workflows/self-scheduled.yml
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fixing the tied_weights after the call.

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 15cd6871
@@ -30,7 +30,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile

 import torch
@@ -573,6 +573,79 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules


+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _get_tied_weight_keys(module: nn.Module, prefix=""):
+    tied_weight_keys = []
+    if getattr(module, "_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
+        tied_weight_keys.extend(names)
+    if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
+        tied_weight_keys.extend(names)
+    for name, submodule in module.named_children():
+        local_prefix = f"{prefix}.{name}" if prefix else name
+        tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
+    return tied_weight_keys
+
+
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
+
+
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
+
+
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -1646,15 +1719,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
-            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder, self.base_model_prefix, "encoder"
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls, e.g. by different tests.
+            self._dynamic_tied_weights_keys = tied_weights

         for module in self.modules():
             if hasattr(module, "_tie_weights"):
                 module._tie_weights()

     @staticmethod
-    def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
+    def _tie_encoder_decoder_weights(
+        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
+    ):
         uninitialized_encoder_weights: List[str] = []
+        tied_weights: List[str] = []
         if decoder.__class__ != encoder.__class__:
             logger.info(
                 f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
@@ -1665,8 +1747,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             decoder_pointer: nn.Module,
             encoder_pointer: nn.Module,
             module_name: str,
+            base_encoder_name: str,
             uninitialized_encoder_weights: List[str],
             depth=0,
+            total_decoder_name="",
+            total_encoder_name="",
         ):
             assert isinstance(decoder_pointer, nn.Module) and isinstance(
                 encoder_pointer, nn.Module
@@ -1674,8 +1759,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            if hasattr(decoder_pointer, "weight"):
                 assert hasattr(encoder_pointer, "weight")
                 encoder_pointer.weight = decoder_pointer.weight
+                tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
                 if hasattr(decoder_pointer, "bias"):
                     assert hasattr(encoder_pointer, "bias")
+                    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
                     encoder_pointer.bias = decoder_pointer.bias
                 return
@@ -1713,19 +1800,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         decoder_modules[decoder_name],
                         encoder_modules[encoder_name],
                         module_name + "/" + name,
+                        base_encoder_name,
                         uninitialized_encoder_weights,
                         depth=depth + 1,
+                        total_encoder_name=f"{total_encoder_name}.{encoder_name}",
+                        total_decoder_name=f"{total_decoder_name}.{decoder_name}",
                     )
                     all_encoder_weights.remove(module_name + "/" + encoder_name)

                 uninitialized_encoder_weights += list(all_encoder_weights)

         # tie weights recursively
-        tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
+        tie_encoder_to_decoder_recursively(
+            decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
+        )
+
         if len(uninitialized_encoder_weights) > 0:
             logger.warning(
                 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
             )
+        return tied_weights
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """Tie or clone module weights depending of whether we are using TorchScript or not"""

@@ -2402,34 +2496,49 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-            warn_names = set()
+            error_names = []
+            to_delete_names = set()
+            # Recursively descend to find tied weight keys
+            _tied_weights_keys = _get_tied_weight_keys(self)
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
-                if self._tied_weights_keys is not None:
+                if _tied_weights_keys is not None:
                     found = 0
                     for name in sorted(names):
-                        matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
+                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                del state_dict[name]
+                                to_delete_names.add(name)
+            # We are entering a place where the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()

             # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
             # If the link between tensors was done at runtime then `from_pretrained` will not get
             # the key back leading to random tensor. A proper warning will be shown
             # during reload (if applicable), but since the file is not necessarily compatible with
             # the config, better show a proper warning.
-                found = 0
-                for name in names:
-                    if name in state_dict:
-                        found += 1
-                        if found > 1:
-                            del state_dict[name]
-                            warn_names.add(name)
-            if len(warn_names) > 0:
-                logger.warning_once(
-                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
-                )
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
+                for name in known:
+                    del state_dict[name]
+                unknown = inames.difference(to_delete_names)
+                if len(unknown) > 1:
+                    error_names.append(unknown)
+
+            if shared_names:
+                error_names.append(set(shared_names))
+
+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
+                )

             # Shard the model if it is too big.
...
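
The comment added above about setting `_dynamic_tied_weights_keys` instead of mutating `_tied_weights_keys` refers to a standard Python pitfall: `_tied_weights_keys` is a class attribute, so mutating it from one instance changes it for every instance of that model class. A generic sketch of the pitfall, with made-up class and attribute names, not code from this diff:

```python
class Model:
    shared_keys = ["lm_head.weight"]  # class attribute: one list shared by all instances


a, b = Model(), Model()
a.shared_keys.append("encoder.embed.weight")  # mutates the class-level list in place
assert b.shared_keys == ["lm_head.weight", "encoder.embed.weight"]  # b sees the change too

a.dynamic_keys = ["encoder.embed.weight"]  # instance attribute: local to `a`
assert not hasattr(b, "dynamic_keys")
```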
@@ -15,7 +15,6 @@
 # limitations under the License.
 """PyTorch BERT model."""
 import math
 import os
 import warnings
@@ -1128,7 +1127,7 @@ class BertForPreTraining(BertPreTrainedModel):
     """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
 )
 class BertLMHeadModel(BertPreTrainedModel):
-    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)
...
@@ -262,9 +262,16 @@ class EncoderDecoderModel(PreTrainedModel):
         if self.config.tie_encoder_decoder:
             # tie encoder and decoder base model
             decoder_base_model_prefix = self.decoder.base_model_prefix
-            self._tie_encoder_decoder_weights(
-                self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "encoder",
             )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls, e.g. by different tests.
+            self._dynamic_tied_weights_keys = tied_weights

     def get_encoder(self):
         return self.encoder
...
@@ -1343,7 +1343,13 @@ class MarianMTModel(MarianPreTrainedModel):
         if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
-            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder, self.base_model_prefix, "encoder"
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls, e.g. by different tests.
+            self._dynamic_tied_weights_keys = tied_weights

         for module in self.modules():
             if hasattr(module, "_tie_weights"):
...
@@ -1891,9 +1891,16 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
         if self.config.tie_encoder_decoder:
             # tie text encoder and decoder base model
             decoder_base_model_prefix = self.decoder.base_model_prefix
-            self._tie_encoder_decoder_weights(
-                self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
-            )
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.text_encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "text_encoder",
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls, e.g. by different tests.
+            self._dynamic_tied_weights_keys = tied_weights

     def get_audio_encoder(self):
         return self.audio_encoder
...
@@ -1810,9 +1810,16 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
         if self.config.tie_encoder_decoder:
             # tie text encoder and decoder base model
             decoder_base_model_prefix = self.decoder.base_model_prefix
-            self._tie_encoder_decoder_weights(
-                self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
-            )
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.text_encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "text_encoder",
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls, e.g. by different tests.
+            self._dynamic_tied_weights_keys = tied_weights

     def get_text_encoder(self):
         return self.text_encoder
...
@@ -101,7 +101,7 @@ if is_torch_available():
         _prepare_4d_attention_mask,
         _prepare_4d_causal_attention_mask,
     )
-    from transformers.modeling_utils import shard_checkpoint
+    from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint

     # Fake pretrained models for tests
     class BaseModel(PreTrainedModel):
@@ -256,6 +256,26 @@ class ModelUtilsTest(TestCasePlus):
         self.assertTrue(check_models_equal(model, model_loaded))

+    def test_model_manually_shared_disjointed_tensors_optimum(self):
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model = BertModel(config)
+
+        # Let's fuse qkv
+        attn = model.encoder.layer[0].attention.self
+        q = attn.query.weight
+        k = attn.key.weight
+        v = attn.value.weight
+        # Force some shared storage
+        qkv = torch.stack([q, k, v], dim=0)
+        attn.query.weight = torch.nn.Parameter(qkv[0])
+        attn.key.weight = torch.nn.Parameter(qkv[1])
+        attn.value.weight = torch.nn.Parameter(qkv[2])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            model_loaded = BertModel.from_pretrained(tmp_dir)
+
+        self.assertTrue(check_models_equal(model, model_loaded))
+
     def test_model_from_pretrained_subfolder_sharded(self):
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
         model = BertModel(config)
@@ -2222,3 +2242,40 @@ class Mask4DTestHard(unittest.TestCase):
         ]
         self.assertEqual(decoded_0, decoded_1b)
+
+
+@require_torch
+class TestTensorSharing(TestCasePlus):
+    def test_disjoint(self):
+        main = torch.zeros(10)
+        a = main[:5]
+        b = main[5:]
+        state_dict = {"a": a, "b": b}
+
+        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [])
+        self.assertEqual(disjoint_names, ["a", "b"])
+
+        a = main[::2]
+        b = main[1::2]
+        state_dict = {"a": a, "b": b}
+
+        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [{"a", "b"}])
+        self.assertEqual(disjoint_names, [])
+
+    def test_identical(self):
+        a = torch.zeros(10)
+        b = a
+        state_dict = {"a": a, "b": b}
+
+        shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [])
+        self.assertEqual(identical_names, [{"a", "b"}])
+
+        b = a[:5]
+        state_dict = {"a": a, "b": b}
+
+        shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [{"a", "b"}])
+        self.assertEqual(identical_names, [])
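
Taken together, the change means that saving a model with undeclared tensor sharing now fails loudly instead of silently dropping weights. A minimal sketch of how the new behavior is expected to surface, using the same tiny test checkpoint as the tests above (the error wording comes from the `RuntimeError` added in this diff; the exact snippet is illustrative, not part of the commit):

```python
import tempfile

from transformers import BertConfig, BertModel

model = BertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert"))
# Create sharing that is NOT declared in any `_tied_weights_keys`:
attn = model.encoder.layer[0].attention.self
attn.key.weight = attn.query.weight  # two state_dict names, one storage

with tempfile.TemporaryDirectory() as tmp_dir:
    try:
        model.save_pretrained(tmp_dir)  # safetensors serialization is the default here
    except RuntimeError as err:
        print(err)  # "...shared tensors ... mismatching the transformers base configuration..."
    # The error message suggests the escape hatch: plain torch.save serialization.
    model.save_pretrained(tmp_dir, safe_serialization=False)
```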