Unverified Commit 2454b98a authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Allow max shard size to be specified when saving pipeline (#9440)

allow max shard size to be specified when saving pipeline
parent 37e3603c
...@@ -189,6 +189,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -189,6 +189,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
safe_serialization: bool = True, safe_serialization: bool = True,
variant: Optional[str] = None, variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False, push_to_hub: bool = False,
**kwargs, **kwargs,
): ):
...@@ -204,6 +205,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -204,6 +205,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `"10GB"`):
                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
                lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
This is to establish a common default size for this argument across different libraries in the Hugging
Face ecosystem (`transformers`, and `accelerate`, for example).
push_to_hub (`bool`, *optional*, defaults to `False`): push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
...@@ -278,12 +286,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -278,12 +286,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_signature = inspect.signature(save_method) save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_kwargs = {} save_kwargs = {}
if save_method_accept_safe: if save_method_accept_safe:
save_kwargs["safe_serialization"] = safe_serialization save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant: if save_method_accept_variant:
save_kwargs["variant"] = variant save_kwargs["variant"] = variant
if save_method_accept_max_shard_size:
save_kwargs["max_shard_size"] = max_shard_size
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment