Unverified commit ff689f57 authored by Benjamin Badger, committed by GitHub

Extend save_pretrained to offloaded models (#27412)



* added hidden subset

* debugged hidden subset contrastive search

* added contrastive search compression

* debugged compressed contrastive search

* memory reduction for contrastive search

* debugged mem red

* added low memory option feature

* debugged mem optimization output stack

* debugged mem optimization output stack

* debugged low mem

* added low mem cache

* fixed 2047 tensor view

* debugged 2042 past key val inputs

* reformatted tensors

* changed low mem output

* final clean

* removed subset hidden csearch

* fixed hidden device

* fixed hidden device

* changed compressor dtype

* removed hstate compression

* integrated csearch in generate

* test csearch integration into generation

exit()

* fixed csearch kwarg integration with generation

* final wrap and added doc

* Update src/transformers/generation/utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* added debug print

* direct hstate cat

* direct hstate cat

* direct hstate cat debug

* direct hstate cat debug

* expanded full hidden state stack

* expanded full hidden state stack

* matched dims for hstates

* matched dims for hstates

* logits fix

* equality test

* equality hidden debug

* debug

* added prints for debug

* added prints for debug

* equality check

* switched squeeze dim

* input format debug

* tracing top_k_ids

* removed trace

* added test context

* added jitter

* added jitter

* added jitter

* returned state

* rebuilt past key value reconstruction

* debugged

* cleaned traces

* added selection for pkv

* changed output to dict

* cleaned

* cleaned

* cleaned up contrastive search test

* moved low_memory kwarg

* debugged

* changed low mem test batch size to 1

* removed output

* debugged test input shape

* reformatted csearch test

* added trace

* removed unsqueeze on final forward pass

* replaced unsqueeze with view

* removed traces

* cleaned

* debugged model kwargs

* removed special models from test

* ran make quality

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* refactored

* refactored

* refactored

* make fixup

* renamed flag sequential

* renamed flag sequential

* iterative onloading

* black style and test utils

* added traces for integrated test

* debugged

* added traces

* make style

* removed traces, make style

* included suggestions and added test

* debugged test

* added offload module check and make style

* is_accelerate_available and make style

* added test decorator

* changed test model and config spec

* added offload condition

* added lazy loading for each shard

* debugged

* modified sharding

* debugged

* added traces

* removed safe serialization

* no index overload;

* trace on safe save ptrs

* added ptr condition

* debugged

* debugged ptr

* moved module map init

* remake shard only for offloaded modules

* refactored

* debugged

* refactored

* debugged

* cleaned and make style

* cleaned and make style

* added trace

* sparse module map

* debugged

* removed module map conditional

* refactored

* debug

* debugged

* added traces

* added shard mem trace

* added shard mem trace

* removed underlying storage check

* refactored

* memory leak removal and make style

* cleaned

* swapped test decs and make style

* added mem checks and make style

* added free mem warning

* implemented some suggestions

* moved onloading to accelerate

* refactored for accelerate integration

* cleaned test

* make style

* debugged offload map name

* cleaned and make style

* replaced meta device check for sharding

* cleaned and make style

* implemented some suggestions

* more suggestions

* update warning
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* more suggestions

* make style

* new make style

* Update src/transformers/modeling_utils.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 8bcf9c8d
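
Before the diff itself, a hedged sketch of the workflow this change enables (the checkpoint name, folder names, and shard size below are illustrative, not part of the commit): a model whose layers accelerate has offloaded to CPU or disk can now be passed straight to save_pretrained, which onloads the weights of each shard as that shard is written.

from transformers import AutoModelForCausalLM

# Illustrative only: any causal LM checkpoint works; "offload" and "resaved-model" are
# placeholder directories, and max_shard_size bounds how much is onloaded at once.
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2",
    device_map="auto",          # let accelerate place layers across GPU/CPU/disk
    offload_folder="offload",   # disk-offloaded weights are stored here
)
model.save_pretrained("resaved-model", max_shard_size="200MB")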
@@ -119,6 +119,10 @@ if is_accelerate_available():
         set_module_tensor_to_device,
     )
 
+    accelerate_version = version.parse(importlib.metadata.version("accelerate"))
+    if accelerate_version >= version.parse("0.31"):
+        from accelerate.utils.modeling import get_state_dict_from_offload
+
 if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import load_file as safe_load_file
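
The helper imported above, get_state_dict_from_offload, only exists in accelerate >= 0.31, which is why the import is version-gated. A minimal sketch (assuming only importlib and packaging) of checking your own environment before relying on the offload-aware save path:

import importlib.metadata
from packaging import version

# Mirrors the guard above: older accelerate versions lack get_state_dict_from_offload,
# and save_pretrained will raise an ImportError later if offloaded modules need onloading.
accelerate_version = version.parse(importlib.metadata.version("accelerate"))
if accelerate_version >= version.parse("0.31"):
    from accelerate.utils.modeling import get_state_dict_from_offload  # noqa: F401
    print("offload-aware saving is available")
else:
    print(f"accelerate {accelerate_version} is too old; run `pip install -U accelerate`")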
@@ -374,13 +378,12 @@ def shard_checkpoint(
             storage_id = id_tensor_storage(weight)
 
         # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
-        if storage_id in storage_id_to_block:
+        if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
             block_id = storage_id_to_block[storage_id]
             sharded_state_dicts[block_id][key] = weight
             continue
 
         weight_size = weight.numel() * dtype_byte_size(weight.dtype)
         # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
         # weight in the current shard.
         if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
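
A note on the new condition: parameters that accelerate has offloaded appear on the meta device in the in-memory model, where storage identity is meaningless, so shard_checkpoint must not group them as shared tensors. A minimal sketch of the situation, assuming nothing beyond plain PyTorch:

import torch
from torch import nn

# Tensors moved to the meta device carry no real storage behind them, so pointer-based
# "same underlying storage" checks are unreliable for them; the added
# `weight.device != torch.device("meta")` condition simply skips that path.
layer = nn.Linear(4, 4).to("meta")
for name, param in layer.named_parameters():
    print(name, param.device, param.device == torch.device("meta"))  # -> meta, True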
@@ -2504,8 +2507,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 current_peft_config = self.peft_config[active_adapter]
                 current_peft_config.save_pretrained(save_directory)
 
+        # for offloaded modules
+        module_map = {}
+
         # Save the model
         if state_dict is None:
+            # if any model parameters are offloaded to the disk, make module map
+            if hasattr(self, "hf_device_map") and (
+                "cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()
+            ):
+                warnings.warn(
+                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
+                )
+                for name, module in model_to_save.named_modules():
+                    if name == "":
+                        continue
+                    module_state_dict = module.state_dict()
+
+                    for key in module_state_dict:
+                        module_map[name + f".{key}"] = module
+
             state_dict = model_to_save.state_dict()
 
         # Translate state_dict from smp to hf if saving with smp >= 1.10
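
The module map built above records, for every key that will appear in the state dict, which submodule owns it, so the real values can be onloaded shard by shard later in the save. A standalone sketch of the same bookkeeping on a toy model (nothing here is library code):

from torch import nn

# Map fully qualified parameter/buffer names to the submodule that owns them.
model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
module_map = {}
for name, module in model.named_modules():
    if name == "":  # skip the root module, as the code above does
        continue
    for key in module.state_dict():
        module_map[f"{name}.{key}"] = module
print(sorted(module_map))  # ['0.bias', '0.weight', '1.bias', '1.weight']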
@@ -2531,12 +2552,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     # In the non-tensor case, fall back to the pointer of the object itself
                     ptrs[id(tensor)].append(name)
 
-            # These are all the pointers of shared tensors.
-            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-            error_names = []
-            to_delete_names = set()
+            # These are all the pointers of shared tensors
+            if hasattr(self, "hf_device_map"):
+                # if the model has offloaded parameters, we must check using find_tied_parameters()
+                tied_params = find_tied_parameters(self)
+                if tied_params:
+                    tied_names = tied_params[0]
+                    shared_ptrs = {
+                        ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
+                    }
+                else:
+                    shared_ptrs = {}
+            else:
+                shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
             # Recursively descend to find tied weight keys
             _tied_weights_keys = _get_tied_weight_keys(self)
+            error_names = []
+            to_delete_names = set()
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
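
When parameters are offloaded, the usual pointer comparison cannot see weight ties (meta tensors expose no shared storage), so the branch above asks accelerate's find_tied_parameters instead. A hedged sketch of what it reports, reusing the tiny checkpoint from the test further down; the exact ordering of names may differ:

from accelerate.utils import find_tied_parameters
from transformers import AutoModelForCausalLM

# GPT-2-style models tie the input embedding to lm_head, so one group of tied
# parameter names is expected; the result is a list of such groups.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
print(find_tied_parameters(model))  # e.g. [['lm_head.weight', 'transformer.wte.weight']]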
@@ -2609,6 +2642,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Save the model
         for shard_file, shard in shards.items():
+            # remake shard with onloaded parameters if necessary
+            if module_map:
+                if accelerate_version < version.parse("0.31"):
+                    raise ImportError(
+                        f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
+                        f"Please upgrade accelerate with `pip install -U accelerate`"
+                    )
+                # init state_dict for this shard
+                state_dict = {name: "" for name in shard}
+                for module_name in shard:
+                    module = module_map[module_name]
+                    # update state dict with onloaded parameters
+                    state_dict = get_state_dict_from_offload(module, module_name, state_dict)
+
+                # assign shard to be the completed state dict
+                shard = state_dict
+                del state_dict
+                gc.collect()
+
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
@@ -1056,6 +1056,43 @@ class ModelUtilsTest(TestCasePlus):
             # This check we did call the fake head request
             mock_head.assert_called()
 
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_accelerator
+    def test_save_offloaded_model(self):
+        device_map = {
+            "transformer.wte": f"{torch_device}:0",
+            "transformer.wpe": f"{torch_device}:0",
+            "transformer.h.0": "cpu",
+            "transformer.h.1": "cpu",
+            "transformer.h.2": "cpu",
+            "transformer.h.3": "disk",
+            "transformer.h.4": "disk",
+            "transformer.ln_f": f"{torch_device}:0",
+            "lm_head": f"{torch_device}:0",
+        }
+
+        # check_models_equal requires onloaded tensors
+        model_id = "hf-internal-testing/tiny-random-gpt2"
+        onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
+        inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
+        cpu_output = onloaded_model(inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            offload_folder = os.path.join(tmp_dir, "offload")
+            offloaded_model = AutoModelForCausalLM.from_pretrained(
+                model_id, device_map=device_map, offload_folder=offload_folder
+            )
+            presaved_output = offloaded_model(inputs)[0]
+            offloaded_model.save_pretrained(
+                tmp_dir, max_shard_size="200KB"
+            )  # model is 1.6MB, max shard size is allocated to cpu by default
+            saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
+            postsaved_output = saved_model(inputs)[0]
+
+        self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
+        self.assertTrue(torch.allclose(presaved_output, postsaved_output))
+
     @require_safetensors
     def test_use_safetensors(self):
         # Should not raise anymore