Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
2454b98a
Unverified
Commit
2454b98a
authored
Sep 16, 2024
by
Aryan
Committed by
GitHub
Sep 16, 2024
Browse files
Allow max shard size to be specified when saving pipeline (#9440)
allow max shard size to be specified when saving pipeline
parent
37e3603c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
0 deletions
+11
-0
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+11
-0
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
2454b98a
...
@@ -189,6 +189,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -189,6 +189,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_directory
:
Union
[
str
,
os
.
PathLike
],
save_directory
:
Union
[
str
,
os
.
PathLike
],
safe_serialization
:
bool
=
True
,
safe_serialization
:
bool
=
True
,
variant
:
Optional
[
str
]
=
None
,
variant
:
Optional
[
str
]
=
None
,
max_shard_size
:
Union
[
int
,
str
]
=
"10GB"
,
push_to_hub
:
bool
=
False
,
push_to_hub
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -204,6 +205,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -204,6 +205,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
This is to establish a common default size for this argument across different libraries in the Hugging
Face ecosystem (`transformers`, and `accelerate`, for example).
push_to_hub (`bool`, *optional*, defaults to `False`):
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
...
@@ -278,12 +286,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -278,12 +286,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_signature
=
inspect
.
signature
(
save_method
)
save_method_signature
=
inspect
.
signature
(
save_method
)
save_method_accept_safe
=
"safe_serialization"
in
save_method_signature
.
parameters
save_method_accept_safe
=
"safe_serialization"
in
save_method_signature
.
parameters
save_method_accept_variant
=
"variant"
in
save_method_signature
.
parameters
save_method_accept_variant
=
"variant"
in
save_method_signature
.
parameters
save_method_accept_max_shard_size
=
"max_shard_size"
in
save_method_signature
.
parameters
save_kwargs
=
{}
save_kwargs
=
{}
if
save_method_accept_safe
:
if
save_method_accept_safe
:
save_kwargs
[
"safe_serialization"
]
=
safe_serialization
save_kwargs
[
"safe_serialization"
]
=
safe_serialization
if
save_method_accept_variant
:
if
save_method_accept_variant
:
save_kwargs
[
"variant"
]
=
variant
save_kwargs
[
"variant"
]
=
variant
if
save_method_accept_max_shard_size
:
save_kwargs
[
"max_shard_size"
]
=
max_shard_size
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
),
**
save_kwargs
)
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
),
**
save_kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment