Unverified Commit 563a8d58 authored by Yih-Dar, committed by GitHub

Delete `state_dict` to release memory as early as possible (#18832)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a26c7523
@@ -417,7 +417,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: nn.Module, prefix=""):
+    def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         if is_deepspeed_zero3_enabled():
@@ -434,9 +434,12 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
         for name, child in module._modules.items():
             if child is not None:
-                load(child, prefix + name + ".")
+                load(child, state_dict, prefix + name + ".")

-    load(model_to_load, prefix=start_prefix)
+    load(model_to_load, state_dict, prefix=start_prefix)
+    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
+    # it's safe to delete it.
+    del state_dict

     return error_msgs
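For context, the change threads `state_dict` through the recursive `load` helper as an explicit argument and then drops the local reference once loading is done, so the potentially multi-gigabyte dict can be garbage-collected before `_load_state_dict_into_model` returns rather than when its frame is torn down. Below is a minimal, self-contained sketch of the same pattern; it is illustrative only, not the transformers implementation, and the helper name `copy_weights_into` is made up:

```python
import torch.nn as nn


def copy_weights_into(model: nn.Module, state_dict: dict) -> None:
    # Work on a shallow copy so the caller's dict is left untouched
    # (`_load_state_dict_into_model` also copies before recursing).
    state_dict = state_dict.copy()

    def load(module: nn.Module, state_dict, prefix=""):
        # `nn.Module._load_from_state_dict` only copies parameters that belong
        # to `module` itself, so children are visited recursively.
        module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])
        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model, state_dict, prefix="")

    # The copy is no longer needed: deleting the last local reference lets the
    # garbage collector reclaim it now, instead of when this frame goes away.
    del state_dict
```

A typical call would be something like `copy_weights_into(model, torch.load("pytorch_model.bin", map_location="cpu"))`; releasing the copied dict right after the recursive load lowers peak host memory while any remaining post-processing still runs.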