Unverified Commit 0e975e5f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Safetensors] Make sure metadata is saved (#2506)

* [Safetensors] Make sure metadata is saved

* make style
parent 7f43f652
...@@ -291,9 +291,6 @@ class ModelMixin(torch.nn.Module): ...@@ -291,9 +291,6 @@ class ModelMixin(torch.nn.Module):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return return
if save_function is None:
save_function = safetensors.torch.save_file if safe_serialization else torch.save
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
model_to_save = self model_to_save = self
...@@ -310,7 +307,12 @@ class ModelMixin(torch.nn.Module): ...@@ -310,7 +307,12 @@ class ModelMixin(torch.nn.Module):
weights_name = _add_variant(weights_name, variant) weights_name = _add_variant(weights_name, variant)
# Save the model # Save the model
save_function(state_dict, os.path.join(save_directory, weights_name)) if safe_serialization:
safetensors.torch.save_file(
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(save_directory, weights_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment