"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f1a7e0c0716094ee85f94fae36e16427121c27d0"
Unverified Commit 62832c96 authored by Stas Bekman, committed by GitHub

1x model size CPU memory usage for `from_pretrained` (#13466)

* one possible solution

* low mem from_pretrained

* edge cases

* solve the persistent buffers

* style

* parametrize

* for later

* proper solution

* cleanup

* refactor; rework based on suggestions

* revert splitting into 2 parts, move checks into main func
parent ca257a06
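
For context before the diff: a minimal usage sketch of the new option (not part of the PR; the checkpoint name below is only illustrative, any PyTorch checkpoint works the same way):

```python
from transformers import AutoModel

# With low_cpu_mem_usage=True, peak CPU memory stays around 1x the model size
# instead of ~2x (loaded state_dict + fully materialized model).
model = AutoModel.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)
```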
@@ -47,6 +47,7 @@ from .file_utils import (
 )
 from .generation_utils import GenerationMixin
 from .utils import logging
+from .utils.versions import require_version_core
 logger = logging.get_logger(__name__)
@@ -1139,6 +1140,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 Please refer to the mirror site for more information.
             _fast_init(:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not to disable fast initialization.
+            low_cpu_mem_usage(:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is an experimental feature and may change at any moment.
             torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
                 Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed
                 the dtype will be automatically derived from the model's weights.
@@ -1209,6 +1213,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         from_auto_class = kwargs.pop("_from_auto", False)
         _fast_init = kwargs.pop("_fast_init", True)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         from_pt = not (from_tf | from_flax)
@@ -1358,6 +1363,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+        if low_cpu_mem_usage:
+            # save the keys
+            loaded_state_dict_keys = [k for k in state_dict.keys()]
+            del state_dict  # free CPU memory - will reload again later
         config.name_or_path = pretrained_model_name_or_path
# Instantiate model. # Instantiate model.
...@@ -1407,6 +1417,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1407,6 +1417,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
raise raise
elif from_pt: elif from_pt:
if low_cpu_mem_usage:
cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
else:
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model( model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
model, model,
state_dict, state_dict,
@@ -1507,10 +1521,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if _fast_init:
             # retrieve unintialized modules and initialize
-            unintialized_modules = model.retrieve_modules_from_names(
+            uninitialized_modules = model.retrieve_modules_from_names(
                 missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
             )
-            for module in unintialized_modules:
+            for module in uninitialized_modules:
                 model._init_weights(module)
         # copy state_dict so _load_from_state_dict can modify it
@@ -1619,6 +1633,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return retrieved_modules

+    @classmethod
+    def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
+        """
+        This is an experimental function that loads the model using ~1x model size CPU memory.
+
+        Before it gets called we do:
+
+        1. save which state_dict keys we have
+        2. drop state_dict before the model is created, since the latter takes 1x model size memory
+
+        Here then we continue:
+
+        3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
+        4. load state_dict 2nd time
+        5. replace the params/buffers from the state_dict
+
+        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
+        """
+        require_version_core("torch>=1.9")
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
+
+        # a helper util to find the last sub-module and the param/buffer name
+        def find_submodule_and_param_name(model, long_key):
+            split_key = long_key.split(".")
+            submodule = model
+            while len(split_key) > 1:
+                if hasattr(submodule, split_key[0]):
+                    submodule = getattr(submodule, split_key[0])
+                    del split_key[0]
+                else:
+                    submodule = None
+                    break
+            return submodule, split_key[0]
+
+        # dematerialize param storage for keys that are going to be replaced by state_dict, by
+        # putting those on the meta device
+        for k in loaded_state_dict_keys:
+            submodule, param_name = find_submodule_and_param_name(model, k)
+            if submodule is not None:
+                # selectively switch to the meta device only those params/buffers that will
+                # be replaced next from state_dict. This is a complex way to do p.to_("meta")
+                # since we have no in-place to_ for tensors.
+                new_val = getattr(submodule, param_name)
+                if isinstance(new_val, torch.nn.Parameter):
+                    # isinstance returns False for Params on the meta device, so switch after the check
+                    new_val = torch.nn.Parameter(new_val.to("meta"))
+                else:
+                    new_val = new_val.to("meta")
+                setattr(submodule, param_name, new_val)
+
+        # only now can we load the state_dict
+        state_dict = torch.load(resolved_archive_file, map_location="cpu")
+
+        # materialize state_dict entries one by one on CPU
+        for k in loaded_state_dict_keys:
+            submodule, param_name = find_submodule_and_param_name(model, k)
+            if submodule is not None:
+                new_val = state_dict[k]
+                if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
+                    new_val = torch.nn.Parameter(new_val)
+                setattr(submodule, param_name, new_val)
+
+        del state_dict
 # To update the docstring, we need to copy the method, otherwise we change the original docstring.
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
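
For readers unfamiliar with the meta device, here is a minimal standalone sketch (not part of the PR) of the dematerialize/rematerialize pattern that `_load_state_dict_into_model_low_mem` relies on. It assumes torch>=1.9; the toy model, the in-memory `checkpoint` dict, and the simplified key-lookup helper (no missing-key handling) are illustrative only:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4, bias=False), nn.Linear(4, 2, bias=False))

# Stand-in for the checkpoint that would normally come from torch.load(resolved_archive_file).
checkpoint = {"0.weight": torch.ones(4, 4), "1.weight": torch.ones(2, 4)}

def find_submodule_and_param_name(model, long_key):
    # walk "0.weight" -> (model[0], "weight"); a simplified version of the PR's helper
    *module_path, param_name = long_key.split(".")
    submodule = model
    for name in module_path:
        submodule = getattr(submodule, name)
    return submodule, param_name

# 1. dematerialize: move each to-be-loaded Parameter to the meta device, freeing its CPU storage
for key in checkpoint:
    submodule, param_name = find_submodule_and_param_name(model, key)
    old = getattr(submodule, param_name)
    setattr(submodule, param_name, nn.Parameter(old.to("meta")))

# 2. rematerialize: assign the loaded tensors back as CPU Parameters, one by one
for key, tensor in checkpoint.items():
    submodule, param_name = find_submodule_and_param_name(model, key)
    setattr(submodule, param_name, nn.Parameter(tensor))

print(model(torch.zeros(1, 4)))  # the model is usable again, now holding the checkpoint weights
```

Because the original weights are dematerialized before the checkpoint is read, only one full copy of the weights lives in CPU memory at any time, which is where the ~1x peak comes from.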