Unverified Commit ab9fe452 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix disk offload for full safetensors checkpoints (#20497)

parent 4aa630ee
......@@ -597,6 +597,9 @@ def _load_state_dict_into_meta_model(
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
param = param.to(dtype)
# For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
if is_safetensors and dtype is None and torch.is_floating_point(param):
param = param.to(torch.float32)
if device_map is None:
param_device = "cpu"
......@@ -2452,6 +2455,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if offload_state_dict is None:
offload_state_dict = True
is_sharded_safetensors = is_safetensors and sharded_metadata is not None
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
......@@ -2567,12 +2571,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
if device_map is not None and is_safetensors:
param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])
param_device_map = expand_device_map(device_map, original_loaded_keys)
str_dtype = str(dtype).replace("torch.", "")
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
archive_file = (
resolved_archive_file[0]
if isinstance(resolved_archive_file, (list, tuple))
else resolved_archive_file
)
weight_map = {p: archive_file for p in original_loaded_keys}
else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
offload_index = {
p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
for p, f in sharded_metadata["weight_map"].items()
p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if param_device_map[p] == "disk"
}
......@@ -2606,7 +2619,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict_folder = None
state_dict_index = None
if is_safetensors:
if is_sharded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment