Unverified commit 76b4f666, authored by Yih-Dar, committed by GitHub

Revert "[WIP] Hard error when ignoring tensors." (#28898)

Revert "[WIP] Hard error when ignoring tensors. (#27484)"

This reverts commit 2da28c4b.
parent 6529a5b5
@@ -29,7 +29,7 @@ import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from zipfile import is_zipfile

 import torch
@@ -570,65 +570,6 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules


-def _end_ptr(tensor: torch.Tensor) -> int:
-    # extract the end of the pointer if the tensor is a slice of a bigger tensor
-    if tensor.nelement():
-        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
-    else:
-        stop = tensor.data_ptr()
-    return stop
-
-
-def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
-    filtered_tensors = []
-    for shared in tensors:
-        if len(shared) < 2:
-            filtered_tensors.append(shared)
-            continue
-
-        areas = []
-        for name in shared:
-            tensor = state_dict[name]
-            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
-        areas.sort()
-
-        _, last_stop, last_name = areas[0]
-        filtered_tensors.append({last_name})
-        for start, stop, name in areas[1:]:
-            if start >= last_stop:
-                filtered_tensors.append({name})
-            else:
-                filtered_tensors[-1].add(name)
-            last_stop = stop
-
-    disjoint_tensors = []
-    shared_tensors = []
-    for tensors in filtered_tensors:
-        if len(tensors) == 1:
-            disjoint_tensors.append(tensors.pop())
-        else:
-            shared_tensors.append(tensors)
-    return shared_tensors, disjoint_tensors
-
-
-def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
-    shared_tensors = []
-    identical = []
-    for shared in tensors:
-        if len(shared) < 2:
-            continue
-        areas = collections.defaultdict(set)
-        for name in shared:
-            tensor = state_dict[name]
-            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
-            areas[area].add(name)
-        if len(areas) == 1:
-            identical.append(shared)
-        else:
-            shared_tensors.append(shared)
-    return shared_tensors, identical
-
-
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -2441,8 +2382,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
             warn_names = set()
-            error_names = set()
-            to_delete_names = set()
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
@@ -2453,42 +2392,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                to_delete_names.add(name)
-            # We are entering a place where the weights and the transformers configuration do NOT match.
-            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-            for name in disjoint_names:
-                state_dict[name] = state_dict[name].clone()
-
-            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-            # If the link between tensors was done at runtime then `from_pretrained` will not get
-            # the key back leading to random tensor. A proper warning will be shown
-            # during reload (if applicable), but since the file is not necessarily compatible with
-            # the config, better show a proper warning.
-            shared_names, identical_names = _find_identical(shared_names, state_dict)
-            # delete tensors that have identical storage
-            for inames in identical_names:
-                known = inames.intersection(to_delete_names)
-                for name in known:
-                    del state_dict[name]
-                unknown = sorted(inames.difference(to_delete_names))
-                for name in unknown[1:]:
-                    del state_dict[name]
-                    warn_names.add(name)
-
-            error_names.update(shared_names)
-
+                                del state_dict[name]
+
+                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+                # If the link between tensors was done at runtime then `from_pretrained` will not get
+                # the key back leading to random tensor. A proper warning will be shown
+                # during reload (if applicable), but since the file is not necessarily compatible with
+                # the config, better show a proper warning.
+                found = 0
+                for name in names:
+                    if name in state_dict:
+                        found += 1
+                        if found > 1:
+                            del state_dict[name]
+                            warn_names.add(name)
             if len(warn_names) > 0:
                 logger.warning_once(
                     f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
                 )
-
-            if len(error_names) > 0:
-                raise RuntimeError(
-                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
-                )

             # Shard the model if it is too big.
             if not _hf_peft_config_loaded:
                 weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
...
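The net effect of this hunk is that `save_pretrained` goes back to dropping all-but-one of the names that point at the same storage and emitting a warning, instead of raising a RuntimeError when the sharing does not match the model's `_tied_weights_keys`. The grouping step itself is unchanged: tied weights are simply several state-dict keys backed by one tensor. A rough sketch of that grouping (illustrative; `TinyTiedLM` is a hypothetical toy model, and grouping by `data_ptr()` is a simplification of the storage-based identifier the real code uses):

import collections

import torch
from torch import nn


class TinyTiedLM(nn.Module):
    # hypothetical toy model, only here to produce tied weights
    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying -> shared storage


model = TinyTiedLM()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[tensor.data_ptr()].append(name)

# Same shape as `shared_ptrs` in the diff: one pointer -> every name that aliases it.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # {<ptr>: ['embed.weight', 'lm_head.weight']}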
@@ -257,26 +257,6 @@ class ModelUtilsTest(TestCasePlus):
             self.assertTrue(check_models_equal(model, model_loaded))

-    def test_model_manually_shared_disjointed_tensors_optimum(self):
-        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
-        model = BertModel(config)
-
-        # Let's fuse qkv
-        attn = model.encoder.layer[0].attention.self
-        q = attn.query.weight
-        k = attn.key.weight
-        v = attn.value.weight
-        # Force some shared storage
-        qkv = torch.stack([q, k, v], dim=0)
-        attn.query.weight = torch.nn.Parameter(qkv[0])
-        attn.key.weight = torch.nn.Parameter(qkv[1])
-        attn.value.weight = torch.nn.Parameter(qkv[2])
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model.save_pretrained(tmp_dir)
-            model_loaded = BertModel.from_pretrained(tmp_dir)
-
-        self.assertTrue(check_models_equal(model, model_loaded))
-
     def test_model_from_pretrained_subfolder_sharded(self):
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
         model = BertModel(config)
...
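The deleted test exercised exactly the case the removed helpers were written for: after `torch.stack`, the q/k/v parameters are views into one buffer that share a storage but occupy disjoint byte ranges. A quick way to see that sharing (illustrative only; assumes a recent PyTorch that provides `untyped_storage()`):

import torch

qkv = torch.randn(3, 4, 4)
q = torch.nn.Parameter(qkv[0])
k = torch.nn.Parameter(qkv[1])
v = torch.nn.Parameter(qkv[2])

# One underlying buffer...
print(len({t.untyped_storage().data_ptr() for t in (q, k, v)}))  # 1
# ...but three distinct start offsets, so the slices never overlap.
print(len({t.data_ptr() for t in (q, k, v)}))  # 3

Since the revert removes the code that cloned such disjoint slices before saving, the save/reload round trip asserted here is no longer expected to hold, which is presumably why the test is deleted together with the helpers.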