Unverified Commit edadfc58 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Better default for offload_state_dict in from_pretrained (#18183)

parent aeeab1ff
......@@ -1687,9 +1687,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
GPU and the available CPU RAM if unset.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*, defaults to `False`):
offload_state_dict (`bool`, *optional*):
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit.
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
......@@ -1775,7 +1776,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
offload_state_dict = kwargs.pop("offload_state_dict", None)
if device_map is not None:
if low_cpu_mem_usage is None:
......@@ -2168,7 +2169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
low_cpu_mem_usage=False,
device_map=None,
offload_folder=None,
offload_state_dict=False,
offload_state_dict=None,
dtype=None,
):
if device_map is not None and "disk" in device_map.values():
......@@ -2178,6 +2179,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
" for them."
)
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment