"docs/vscode:/vscode.git/clone" did not exist on "3f4e79d29ce32d9f8f75b082836b01ee180d0966"
Unverified Commit 62832c96 authored by Stas Bekman, committed by GitHub

1x model size CPU memory usage for `from_pretrained` (#13466)

* one possible solution

* low mem from_pretrained

* edge cases

* solve the persistent buffers

* style

* parametrize

* for later

* proper solution

* cleanup

* refactor; rework based on suggestions

* revert splitting into 2 parts, move checks into main func
parent ca257a06
@@ -47,6 +47,7 @@ from .file_utils import (
)
from .generation_utils import GenerationMixin
from .utils import logging
from .utils.versions import require_version_core
logger = logging.get_logger(__name__)
@@ -1139,6 +1140,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Please refer to the mirror site for more information.
_fast_init(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to disable fast initialization.
low_cpu_mem_usage(:obj:`bool`, `optional`, defaults to :obj:`False`):
Tries to not use more than 1x the model size in CPU memory (including peak memory) while loading the model.
This is an experimental feature and subject to change at any moment.
torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
dtype will be automatically derived from the model's weights.
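A minimal usage sketch of the new flag (the checkpoint name below is illustrative, not taken from this diff):

from transformers import AutoModel

# requires torch>=1.9; peak CPU memory should stay close to 1x the model size while loading
model = AutoModel.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)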
@@ -1209,6 +1213,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
from_pt = not (from_tf | from_flax)
@@ -1358,6 +1363,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
if low_cpu_mem_usage:
# save the keys
loaded_state_dict_keys = [k for k in state_dict.keys()]
del state_dict # free CPU memory - will reload again later
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
@@ -1407,6 +1417,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
raise
elif from_pt:
if low_cpu_mem_usage:
cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
else:
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
model,
state_dict,
@@ -1507,10 +1521,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if _fast_init:
# retrieve uninitialized modules and initialize
unintialized_modules = model.retrieve_modules_from_names(
uninitialized_modules = model.retrieve_modules_from_names(
missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
)
for module in unintialized_modules:
for module in uninitialized_modules:
model._init_weights(module)
# copy state_dict so _load_from_state_dict can modify it
@@ -1619,6 +1633,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return retrieved_modules
@classmethod
def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
"""
This is an experimental function that loads the model using ~1x the model size in CPU memory.
Before it gets called we do:
1. save which state_dict keys we have
2. drop state_dict before the model is created, since the latter takes 1x model size memory
Here then we continue:
3. switch all params/buffers that are going to be replaced from the loaded state_dict to the meta device
4. load state_dict a 2nd time
5. replace the params/buffers from the state_dict
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys, and it can't handle deepspeed.
"""
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
# a helper util to find the last sub-module and the param/buffer name
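# e.g. an illustrative BERT-style key "encoder.layer.0.attention.self.query.weight" resolves to the
# innermost `query` Linear submodule plus the name "weight"; keys that don't resolve return None for the submodule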
def find_submodule_and_param_name(model, long_key):
split_key = long_key.split(".")
submodule = model
while len(split_key) > 1:
if hasattr(submodule, split_key[0]):
submodule = getattr(submodule, split_key[0])
del split_key[0]
else:
submodule = None
break
return submodule, split_key[0]
# dematerialize param storage for keys that are going to be replaced by state_dict, by
# putting those on the meta device
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
# selectively switch to the meta device only those params/buffers that will
# be next replaced from state_dict. This is a complex way to do p.to_("meta")
# since we have no in-place to_ for tensors.
new_val = getattr(submodule, param_name)
if isinstance(new_val, torch.nn.Parameter):
# isinstance returns False for Params on meta device, so switch after the check
new_val = torch.nn.Parameter(new_val.to("meta"))
else:
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)
# only now can we load the state_dict
state_dict = torch.load(resolved_archive_file, map_location="cpu")
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
new_val = state_dict[k]
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
del state_dict
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
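A standalone sketch (not part of this diff) of the per-parameter meta-device swap that the new `_load_state_dict_into_model_low_mem` performs, assuming torch>=1.9:

import torch

layer = torch.nn.Linear(4, 4)
# step 3: dematerialize - replace the parameter with a meta-device copy that holds no storage
layer.weight = torch.nn.Parameter(layer.weight.to("meta"))
# steps 4-5: rematerialize - assign the tensor loaded from the checkpoint in its place
loaded_weight = torch.rand(4, 4)  # stands in for a tensor from the checkpoint's state_dict
layer.weight = torch.nn.Parameter(loaded_weight)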