"vscode:/vscode.git/clone" did not exist on "955b2b97a69cf071c3517afaadf10bdf5ff77e1b"
Unverified commit edadfc58, authored by Sylvain Gugger, committed by GitHub
Browse files

Better default for offload_state_dict in from_pretrained (#18183)

parent aeeab1ff
...@@ -1687,9 +1687,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1687,9 +1687,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
GPU and the available CPU RAM if unset. GPU and the available CPU RAM if unset.
offload_folder (`str` or `os.PathLike`, *optional*): offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*, defaults to `False`): offload_state_dict (`bool`, *optional*):
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
`True` when there is some disk offload.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
...@@ -1775,7 +1776,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1775,7 +1776,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None) max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None) offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False) offload_state_dict = kwargs.pop("offload_state_dict", None)
if device_map is not None: if device_map is not None:
if low_cpu_mem_usage is None: if low_cpu_mem_usage is None:
...@@ -2168,7 +2169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2168,7 +2169,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
low_cpu_mem_usage=False, low_cpu_mem_usage=False,
device_map=None, device_map=None,
offload_folder=None, offload_folder=None,
offload_state_dict=False, offload_state_dict=None,
dtype=None, dtype=None,
): ):
if device_map is not None and "disk" in device_map.values(): if device_map is not None and "disk" in device_map.values():
...@@ -2178,6 +2179,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2178,6 +2179,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
" for them." " for them."
) )
os.makedirs(offload_folder, exist_ok=True) os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
# Retrieve missing & unexpected_keys # Retrieve missing & unexpected_keys
model_state_dict = model.state_dict() model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys()) expected_keys = list(model_state_dict.keys())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment