Unverified Commit f354dd9e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Save Pretrained] Remove dead code lines that can accidentally remove pytorch files (#2038)

correct safetensors
parent 007c914c
...@@ -272,15 +272,6 @@ class ModelMixin(torch.nn.Module): ...@@ -272,15 +272,6 @@ class ModelMixin(torch.nn.Module):
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
# Save the model # Save the model
save_function(state_dict, os.path.join(save_directory, weights_name)) save_function(state_dict, os.path.join(save_directory, weights_name))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment