Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
214372aa
Unverified
Commit
214372aa
authored
Aug 21, 2024
by
YiYi Xu
Committed by
GitHub
Aug 21, 2024
Browse files
fix a regression in `is_safetensors_compatible` (#9234)
fix
parent
867e0c91
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
3 deletions
+33
-3
src/diffusers/pipelines/pipeline_loading_utils.py
src/diffusers/pipelines/pipeline_loading_utils.py
+3
-1
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+6
-2
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+24
-0
No files found.
src/diffusers/pipelines/pipeline_loading_utils.py
View file @
214372aa
...
...
@@ -89,7 +89,7 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES
.
update
(
LOADABLE_CLASSES
[
library
])
def
is_safetensors_compatible
(
filenames
,
passed_components
=
None
)
->
bool
:
def
is_safetensors_compatible
(
filenames
,
passed_components
=
None
,
folder_names
=
None
)
->
bool
:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
...
...
@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
extension is replaced with ".safetensors"
"""
passed_components
=
passed_components
or
[]
if
folder_names
is
not
None
:
filenames
=
{
f
for
f
in
filenames
if
os
.
path
.
split
(
f
)[
0
]
in
folder_names
}
# extract all components of the pipeline and their associated files
components
=
{}
...
...
src/diffusers/pipelines/pipeline_utils.py
View file @
214372aa
...
...
@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if
(
use_safetensors
and
not
allow_pickle
and
not
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
)
and
not
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
,
folder_names
=
model_folder_names
)
):
raise
EnvironmentError
(
f
"Could not find the necessary `safetensors` weights in
{
model_filenames
}
(variant=
{
variant
}
)"
)
if
from_flax
:
ignore_patterns
=
[
"*.bin"
,
"*.safetensors"
,
"*.onnx"
,
"*.pb"
]
elif
use_safetensors
and
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
):
elif
use_safetensors
and
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
,
folder_names
=
model_folder_names
):
ignore_patterns
=
[
"*.bin"
,
"*.msgpack"
]
use_onnx
=
use_onnx
if
use_onnx
is
not
None
else
pipeline_class
.
_is_onnx
...
...
tests/pipelines/test_pipeline_utils.py
View file @
214372aa
...
...
@@ -116,6 +116,30 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_compatible_variant_extra_folder
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
folder_names
=
{
"vae"
,
"unet"
}))
def
test_transformer_model_is_not_compatible_variant_extra_folder
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
folder_names
=
{
"text_encoder"
}))
def
test_transformers_is_compatible_sharded
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment