Unverified Commit ab9fe452 authored by Sylvain Gugger, committed by GitHub

Fix disk offload for full safetensors checkpoints (#20497)

parent 4aa630ee
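For reference, a minimal way to exercise the code path this commit fixes (the model id and offload folder below are placeholders, not taken from the commit): loading a checkpoint saved as a single, non-sharded safetensors file with a device map that sends some weights to disk.

```python
from transformers import AutoModelForCausalLM

# Hypothetical repo id and paths, for illustration only: any model whose checkpoint
# is a single (non-sharded) safetensors file will do. Before this fix, building the
# disk-offload index assumed sharded metadata was always available.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-safetensors-model",   # placeholder model id
    device_map="auto",                   # let accelerate place weights on GPU/CPU/disk
    offload_folder="offload",            # where disk-offloaded weights are written
    offload_state_dict=True,
)
```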
@@ -597,6 +597,9 @@ def _load_state_dict_into_meta_model(
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
+        # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
+        if is_safetensors and dtype is None and torch.is_floating_point(param):
+            param = param.to(torch.float32)

         if device_map is None:
             param_device = "cpu"
@@ -2452,6 +2455,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if offload_state_dict is None:
             offload_state_dict = True

+        is_sharded_safetensors = is_safetensors and sharded_metadata is not None
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
@@ -2567,12 +2571,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
         if device_map is not None and is_safetensors:
-            param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])
+            param_device_map = expand_device_map(device_map, original_loaded_keys)

-            str_dtype = str(dtype).replace("torch.", "")
+            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
+            if sharded_metadata is None:
+                archive_file = (
+                    resolved_archive_file[0]
+                    if isinstance(resolved_archive_file, (list, tuple))
+                    else resolved_archive_file
+                )
+                weight_map = {p: archive_file for p in original_loaded_keys}
+            else:
+                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
             offload_index = {
-                p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
-                for p, f in sharded_metadata["weight_map"].items()
+                p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
+                for p, f in weight_map.items()
                 if param_device_map[p] == "disk"
             }
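To make the new hunk easier to follow, here is a self-contained sketch of the same logic (the helper name and its arguments are made up for illustration; in `from_pretrained` these values already exist as locals): the weight-to-file map now falls back to the single archive file when there is no sharded metadata, and only disk-placed parameters end up in the offload index.

```python
import os

def build_offload_index(resolved_archive_file, sharded_metadata, original_loaded_keys,
                        param_device_map, folder, str_dtype):
    # Hypothetical helper mirroring the hunk above, not a transformers API.
    if sharded_metadata is None:
        # Full checkpoint: every parameter lives in the single safetensors file.
        archive_file = (
            resolved_archive_file[0]
            if isinstance(resolved_archive_file, (list, tuple))
            else resolved_archive_file
        )
        weight_map = {p: archive_file for p in original_loaded_keys}
    else:
        # Sharded checkpoint: the metadata already maps each parameter to its shard.
        weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
    return {
        p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
        for p, f in weight_map.items()
        if param_device_map[p] == "disk"
    }

# Example: a two-tensor full checkpoint where one tensor is offloaded to disk.
index = build_offload_index(
    "model.safetensors", None, ["lm_head.weight", "embed.weight"],
    {"lm_head.weight": "disk", "embed.weight": 0}, ".", "float32",
)
print(index)  # only lm_head.weight appears, pointing at model.safetensors
```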
@@ -2606,7 +2619,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             state_dict_folder = None
             state_dict_index = None

-        if is_safetensors:
+        if is_sharded_safetensors:
             disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
             disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
         else:
...
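The last hunk exists because `sharded_metadata` is `None` for a full checkpoint, so shard-level bookkeeping has to be skipped entirely; a tiny illustration with assumed values:

```python
# With a full (non-sharded) safetensors checkpoint there is no shard metadata,
# so gating the shard-only code path on is_safetensors alone would hand a None
# value to the shard bookkeeping; the new flag makes the condition explicit.
is_safetensors = True
sharded_metadata = None  # full checkpoint

is_sharded_safetensors = is_safetensors and sharded_metadata is not None
print(is_sharded_safetensors)  # False -> get_disk_only_shard_files is not called
```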