Unverified Commit 6b548129 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Hub`] Add `safe_serialization` in push_to_hub (#24074)

add `safe_serialization` in push_to_hub
parent 6daf7c31
...@@ -740,6 +740,7 @@ class PushToHubMixin: ...@@ -740,6 +740,7 @@ class PushToHubMixin:
use_auth_token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None,
max_shard_size: Optional[Union[int, str]] = "10GB", max_shard_size: Optional[Union[int, str]] = "10GB",
create_pr: bool = False, create_pr: bool = False,
safe_serialization: bool = False,
**deprecated_kwargs, **deprecated_kwargs,
) -> str: ) -> str:
""" """
...@@ -767,6 +768,8 @@ class PushToHubMixin: ...@@ -767,6 +768,8 @@ class PushToHubMixin:
by a unit (like `"5MB"`). by a unit (like `"5MB"`).
create_pr (`bool`, *optional*, defaults to `False`): create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit. Whether or not to create a PR with the uploaded files or directly commit.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether or not to convert the model weights in safetensors format for safer serialization.
Examples: Examples:
...@@ -809,7 +812,7 @@ class PushToHubMixin: ...@@ -809,7 +812,7 @@ class PushToHubMixin:
files_timestamps = self._get_files_timestamps(work_dir) files_timestamps = self._get_files_timestamps(work_dir)
# Save all files. # Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size) self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
return self._upload_modified_files( return self._upload_modified_files(
work_dir, work_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment