Unverified Commit da18fbd5 authored by Aryan, committed by GitHub

set max_shard_size to None for pipeline save_pretrained (#9447)



* update default max_shard_size

* add None check to fix tests

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent ba06124e
@@ -189,7 +189,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         save_directory: Union[str, os.PathLike],
         safe_serialization: bool = True,
         variant: Optional[str] = None,
-        max_shard_size: Union[int, str] = "10GB",
+        max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -205,7 +205,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
-            max_shard_size (`int` or `str`, defaults to `"10GB"`):
+            max_shard_size (`int` or `str`, defaults to `None`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
                 If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
@@ -293,7 +293,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 save_kwargs["safe_serialization"] = safe_serialization
             if save_method_accept_variant:
                 save_kwargs["variant"] = variant
-            if save_method_accept_max_shard_size:
+            if save_method_accept_max_shard_size and max_shard_size is not None:
+                # max_shard_size is expected to not be None in ModelMixin
                 save_kwargs["max_shard_size"] = max_shard_size
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
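For context, a minimal usage sketch of the new behavior (the model id, paths, and shard size below are illustrative assumptions, not part of the commit): with the new default `max_shard_size=None`, the pipeline no longer forwards a shard size, so each component falls back to the default of its own save method (e.g. ModelMixin.save_pretrained), while an explicit value is still passed through to components that accept it.

from diffusers import DiffusionPipeline

# Illustrative sketch; the model id and output paths are example values.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# New default: max_shard_size=None, so no max_shard_size kwarg is forwarded and
# each component uses its own save_pretrained default.
pipe.save_pretrained("./sd15-local")

# An explicit value is still forwarded to components whose save method accepts it.
pipe.save_pretrained("./sd15-sharded", max_shard_size="5GB")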