Unverified Commit 30163921 authored by Sylvain Gugger, committed by GitHub

Safetensors offload (#20321)

* Integrate safetensors in weight offloading

* Use safetensors checkpoint for offload when available

* Make naming consistent

* Make load faster

* Quality

* Add default
parent ac2f6674
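With this change, a checkpoint shipped as `.safetensors` shards can serve disk-offloaded weights directly from the shard files, so no extra copy is written to an `offload_folder`. A minimal usage sketch of the behavior this enables (the model id is illustrative, and `safetensors` plus `accelerate` are assumed installed):

```python
from transformers import AutoModelForCausalLM

# When the checkpoint provides safetensors shards, weights mapped to "disk"
# are read lazily from those shards at forward time instead of being
# re-serialized into an offload folder first.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",  # illustrative model id
    device_map="auto",        # may place some modules on "disk"
)
```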
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections
 import gc
 import json
 import os
@@ -28,7 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from packaging import version
-from torch import Tensor, device, nn
+from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss
 from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
@@ -545,6 +545,7 @@ def _load_state_dict_into_meta_model(
     state_dict_index=None,
     dtype=None,
     load_in_8bit=False,
+    is_safetensors=False,
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -609,7 +610,8 @@ def _load_state_dict_into_meta_model(
             raise ValueError(f"{param_name} doesn't have any device set.")
         param_device = device_map[module_name]
         if param_device == "disk":
-            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
+            if not is_safetensors:
+                offload_index = offload_weight(param, param_name, offload_folder, offload_index)
         elif param_device == "cpu" and state_dict_index is not None:
             state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
         elif not load_in_8bit:
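For contrast, the legacy path that `is_safetensors` now skips serializes each disk-mapped parameter to its own `.dat` file through accelerate's `offload_weight`. A rough sketch of that behavior, from memory of accelerate's internals rather than this diff, so treat the details as assumptions:

```python
import os
import numpy as np

# Rough sketch of the non-safetensors offload path: dump the tensor to a
# memory-mapped .dat file and record its dtype/shape in the index. The
# function name is hypothetical; it only mirrors what offload_weight does.
def offload_weight_sketch(param, param_name, offload_folder, index):
    array = param.cpu().numpy()
    index[param_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
    file_path = os.path.join(offload_folder, f"{param_name}.dat")
    mmap = np.memmap(file_path, dtype=array.dtype, mode="w+", shape=array.shape)
    mmap[:] = array[:]
    mmap.flush()
    return index
```

Skipping this write when the weights already live in safetensors shards is what saves both time and disk space.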
@@ -673,7 +675,7 @@ class ModuleUtilsMixin:
         module.mem_rss_pre_forward = 0

     @property
-    def device(self) -> device:
+    def device(self) -> torch.device:
         """
         `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
         device).
@@ -2364,7 +2366,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)

-        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = cls._load_pretrained_model(
             model,
             state_dict,
             loaded_state_dict_keys,  # XXX: rename?
@@ -2391,7 +2400,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Dispatch model with hooks on all devices if necessary
         if device_map is not None:
-            dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
+            dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)

         if output_loading_info:
             if loading_info is None:
@@ -2423,16 +2432,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         dtype=None,
         load_in_8bit=False,
     ):
+        is_safetensors = False
         if load_in_8bit:
             from .utils.bitsandbytes import set_module_8bit_tensor_to_device

         if device_map is not None and "disk" in device_map.values():
-            if offload_folder is None:
+            archive_file = (
+                resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
+            )
+            is_safetensors = archive_file.endswith(".safetensors")
+            if offload_folder is None and not is_safetensors:
                 raise ValueError(
                     "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
-                    " for them."
+                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+                    " offers the weights in this format."
                 )
-            os.makedirs(offload_folder, exist_ok=True)
+            if offload_folder is not None:
+                os.makedirs(offload_folder, exist_ok=True)
             if offload_state_dict is None:
                 offload_state_dict = True
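Concretely, the detection above only inspects the extension of the first resolved shard. Illustrative values (the cache paths are made up):

```python
# Illustrative inputs for the detection performed in the hunk above.
resolved_archive_file = [
    "/cache/model-00001-of-00002.safetensors",
    "/cache/model-00002-of-00002.safetensors",
]
archive_file = resolved_archive_file[0]
is_safetensors = archive_file.endswith(".safetensors")  # True: shards can back the offload
```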
@@ -2549,6 +2565,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                         del state_dict[checkpoint_key]
             return mismatched_keys

+        folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
+        if device_map is not None and is_safetensors:
+            param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])
+
+            str_dtype = str(dtype).replace("torch.", "")
+            offload_index = {
+                p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
+                for p, f in sharded_metadata["weight_map"].items()
+                if param_device_map[p] == "disk"
+            }
+
         if state_dict is not None:
             # Whole checkpoint
             mismatched_keys = _find_mismatched_keys(
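The `offload_index` built above points each disk-mapped parameter straight at the checkpoint shard that already holds it. Schematically, with illustrative parameter and file names:

```python
# One entry per parameter whose device in the expanded device map is "disk".
offload_index = {
    "h.0.self_attention.query_key_value.weight": {
        "safetensors_file": "/cache/model-00002-of-00002.safetensors",
        "weight_name": "h.0.self_attention.query_key_value.weight",
        "dtype": "float16",
    },
}
```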
@@ -2560,6 +2587,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 ignore_mismatched_sizes,
             )
             error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+            offload_index = None
         else:
             # Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -2569,7 +2597,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             error_msgs = []
             mismatched_keys = []
-            offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+            if not is_safetensors:
+                offload_index = {} if device_map is not None and "disk" in device_map.values() else None
             if offload_state_dict:
                 state_dict_folder = tempfile.mkdtemp()
                 state_dict_index = {}
@@ -2577,7 +2606,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 state_dict_folder = None
                 state_dict_index = None

+            if is_safetensors:
+                disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
+                disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
+            else:
+                disk_only_shard_files = []
+
             for shard_file in resolved_archive_file:
+                # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
+                if shard_file in disk_only_shard_files:
+                    continue
                 state_dict = load_state_dict(shard_file)

                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
@@ -2605,6 +2643,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     state_dict_index=state_dict_index,
                     dtype=dtype,
                     load_in_8bit=load_in_8bit,
+                    is_safetensors=is_safetensors,
                 )
                 error_msgs += new_error_msgs
             else:
@@ -2618,13 +2657,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if model != model_to_load:
                 # We need to add the prefix of the base model
                 prefix = cls.base_model_prefix
-                for weight_name in offload_index:
-                    shutil.move(
-                        os.path.join(offload_folder, f"{weight_name}.dat"),
-                        os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
-                    )
+                if not is_safetensors:
+                    for weight_name in offload_index:
+                        shutil.move(
+                            os.path.join(offload_folder, f"{weight_name}.dat"),
+                            os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
+                        )
                 offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
-            save_offload_index(offload_index, offload_folder)
+            if not is_safetensors:
+                save_offload_index(offload_index, offload_folder)
+                offload_index = None

         if offload_state_dict:
             # Load back temporarily offloaded state dict
@@ -2678,7 +2720,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 " to use it for predictions and inference."
             )

-        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs

     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
         module_keys = set([".".join(key.split(".")[:-1]) for key in names])
@@ -3191,3 +3233,26 @@ def unwrap_model(model: nn.Module) -> nn.Module:
         return unwrap_model(model.module)
     else:
         return model
+
+
+def expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence between parameter names and devices.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
+    return new_device_map
+
+
+def get_disk_only_shard_files(device_map, sharded_metadata):
+    """
+    Returns the list of shard files containing only weights offloaded to disk.
+    """
+    files_content = collections.defaultdict(list)
+    for weight_name, filename in sharded_metadata["weight_map"].items():
+        while len(weight_name) > 0 and weight_name not in device_map:
+            weight_name = ".".join(weight_name.split(".")[:-1])
+        files_content[filename].append(device_map[weight_name])
+
+    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
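A quick toy run of the two helpers added above (module and file names are made up):

```python
device_map = {"encoder": 0, "decoder": "disk"}
sharded_metadata = {
    "weight_map": {
        "encoder.layer.weight": "model-00001-of-00002.safetensors",
        "decoder.layer.weight": "model-00002-of-00002.safetensors",
    }
}

expand_device_map(device_map, ["encoder.layer.weight", "decoder.layer.weight"])
# -> {'encoder.layer.weight': 0, 'decoder.layer.weight': 'disk'}

get_disk_only_shard_files(device_map, sharded_metadata)
# -> ['model-00002-of-00002.safetensors']  (this shard never needs to enter RAM)
```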
@@ -910,6 +910,7 @@ def get_checkpoint_shard_files(
     shard_filenames = sorted(list(set(index["weight_map"].values())))
     sharded_metadata = index["metadata"]
     sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+    sharded_metadata["weight_map"] = index["weight_map"].copy()

     # First, let's deal with local folder.
     if os.path.isdir(pretrained_model_name_or_path):
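For reference, `index` here is the parsed shard index file of the checkpoint (e.g. `model.safetensors.index.json`), and `weight_map` is the part now copied into `sharded_metadata` so the offload code can find each weight's shard. Roughly, with illustrative names and sizes:

```python
# Approximate shape of a sharded checkpoint index; values are illustrative.
index = {
    "metadata": {"total_size": 28_500_000_000},
    "weight_map": {
        "h.0.attn.weight": "model-00001-of-00002.safetensors",
        "h.1.attn.weight": "model-00002-of-00002.safetensors",
    },
}
```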