renzhc / diffusers_dcu · Commit b2add10d (unverified)

Update `is_safetensors_compatible` check (#8991)

* update * update * update * update * update

Authored Aug 19, 2024 by Dhruv Nair; committed by GitHub on Aug 19, 2024
Parent: 815d8822

Showing 4 changed files with 151 additions and 82 deletions (+151, -82)
src/diffusers/pipelines/pipeline_loading_utils.py  +21 -26
src/diffusers/pipelines/pipeline_utils.py          +2  -6
tests/pipelines/test_pipeline_utils.py             +46 -25
tests/pipelines/test_pipelines.py                  +82 -25
src/diffusers/pipelines/pipeline_loading_utils.py  (view file @ b2add10d)

@@ -89,49 +89,44 @@ for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


-def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None) -> bool:
     """
     Checking for safetensors compatibility:
-    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
-      files to know which safetensors files are needed.
-    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
+    - The model is safetensors compatible only if there is a safetensors file for each model component present in
+      filenames.

     Converting default pytorch serialized filenames to safetensors serialized filenames:
     - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
     - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
       extension is replaced with ".safetensors"
     """
-    pt_filenames = []
-
-    sf_filenames = set()
-
     passed_components = passed_components or []

+    # extract all components of the pipeline and their associated files
+    components = {}
     for filename in filenames:
-        _, extension = os.path.splitext(filename)
+        if not len(filename.split("/")) == 2:
+            continue

-        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+        component, component_filename = filename.split("/")
+        if component in passed_components:
             continue

-        if extension == ".bin":
-            pt_filenames.append(os.path.normpath(filename))
-        elif extension == ".safetensors":
-            sf_filenames.add(os.path.normpath(filename))
+        components.setdefault(component, [])
+        components[component].append(component_filename)

-    for filename in pt_filenames:
-        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
-        path, filename = os.path.split(filename)
-        filename, extension = os.path.splitext(filename)
+    # iterate over all files of a component
+    # check if safetensor files exist for that component
+    # if variant is provided check if the variant of the safetensors exists
+    for component, component_filenames in components.items():
+        matches = []
+        for component_filename in component_filenames:
+            filename, extension = os.path.splitext(component_filename)

-        if filename.startswith("pytorch_model"):
-            filename = filename.replace("pytorch_model", "model")
-        else:
-            filename = filename
+            match_exists = extension == ".safetensors"
+            matches.append(match_exists)

-        expected_sf_filename = os.path.normpath(os.path.join(path, filename))
-        expected_sf_filename = f"{expected_sf_filename}.safetensors"
-        if expected_sf_filename not in sf_filenames:
-            logger.warning(f"{expected_sf_filename} not found")
+        if not any(matches):
             return False

     return True
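To make the behavioural change concrete: the rewritten check groups files by their top-level component directory and only asks whether each component has at least one ".safetensors" file, instead of requiring a safetensors counterpart for every ".bin" file. Below is a minimal, self-contained sketch of that rule (an illustrative reimplementation, not the library code; the helper name `component_has_safetensors` is invented here, and the example filename lists are taken from the tests in this commit).

import os


def component_has_safetensors(filenames, passed_components=None):
    """Sketch of the new rule: every "component/file" entry must belong to a
    component that has at least one ".safetensors" file (ignoring passed components)."""
    passed_components = passed_components or []
    components = {}
    for filename in filenames:
        parts = filename.split("/")
        if len(parts) != 2 or parts[0] in passed_components:
            continue
        components.setdefault(parts[0], []).append(parts[1])

    # a component passes if any of its files has the .safetensors extension
    return all(
        any(os.path.splitext(name)[1] == ".safetensors" for name in names)
        for names in components.values()
    )


# Mirrors test_transformer_model_is_compatible_variant: the fp16 variant counts as a match.
print(component_has_safetensors([
    "text_encoder/pytorch_model.fp16.bin",
    "text_encoder/model.fp16.safetensors",
]))  # True

# Mirrors the assertFalse case in the tests: text_encoder only has a .bin file.
print(component_has_safetensors([
    "vae/diffusion_pytorch_model.fp16.bin",
    "vae/diffusion_pytorch_model.fp16.safetensors",
    "text_encoder/pytorch_model.fp16.bin",
    "unet/diffusion_pytorch_model.fp16.bin",
    "unet/diffusion_pytorch_model.fp16.safetensors",
]))  # False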
src/diffusers/pipelines/pipeline_utils.py  (view file @ b2add10d)

@@ -1416,18 +1416,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if (
                 use_safetensors
                 and not allow_pickle
-                and not is_safetensors_compatible(
-                    model_filenames, variant=variant, passed_components=passed_components
-                )
+                and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
             ):
                 raise EnvironmentError(
                     f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                 )
             if from_flax:
                 ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(
-                model_filenames, variant=variant, passed_components=passed_components
-            ):
+            elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
                 ignore_patterns = ["*.bin", "*.msgpack"]

                 use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
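For context, the branch above decides which weight files the downloader should skip. A condensed, hypothetical sketch of that decision, assuming `compatible` holds the result of `is_safetensors_compatible(model_filenames, passed_components=passed_components)` (the function name, parameter list, and simplified error message are made up for illustration; the real logic lives inside `DiffusionPipeline.download`):

def choose_ignore_patterns(model_filenames, use_safetensors, allow_pickle, from_flax, compatible):
    # `compatible` stands in for
    # is_safetensors_compatible(model_filenames, passed_components=passed_components)
    if use_safetensors and not allow_pickle and not compatible:
        raise EnvironmentError(
            f"Could not find the necessary `safetensors` weights in {model_filenames}"
        )
    if from_flax:
        # Flax checkpoints: drop every other serialization format
        return ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
    if use_safetensors and compatible:
        # every component has safetensors weights, so the .bin files can be skipped
        return ["*.bin", "*.msgpack"]
    # remaining branches of the real method fall outside this hunk
    return []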
tests/pipelines/test_pipeline_utils.py  (view file @ b2add10d)

@@ -68,25 +68,21 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))

     def test_diffusers_model_is_compatible_variant(self):
         filenames = [
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))

-    def test_diffusers_model_is_compatible_variant_partial(self):
-        # pass variant but use the non-variant filenames
+    def test_diffusers_model_is_compatible_variant_mixed(self):
         filenames = [
             "unet/diffusion_pytorch_model.bin",
-            "unet/diffusion_pytorch_model.safetensors",
+            "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))

     def test_diffusers_model_is_not_compatible_variant(self):
         filenames = [

@@ -99,25 +95,14 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "unet/diffusion_pytorch_model.fp16.bin",
             # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
         ]
-        variant = "fp16"
-        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
+        self.assertFalse(is_safetensors_compatible(filenames))

     def test_transformer_model_is_compatible_variant(self):
         filenames = [
             "text_encoder/pytorch_model.fp16.bin",
             "text_encoder/model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
-
-    def test_transformer_model_is_compatible_variant_partial(self):
-        # pass variant but use the non-variant filenames
-        filenames = [
-            "text_encoder/pytorch_model.bin",
-            "text_encoder/model.safetensors",
-        ]
-        variant = "fp16"
-        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
+        self.assertTrue(is_safetensors_compatible(filenames))

     def test_transformer_model_is_not_compatible_variant(self):
         filenames = [

@@ -126,9 +111,45 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "vae/diffusion_pytorch_model.fp16.bin",
             "vae/diffusion_pytorch_model.fp16.safetensors",
             "text_encoder/pytorch_model.fp16.bin",
             # 'text_encoder/model.fp16.safetensors',
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        variant = "fp16"
-        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
+        self.assertFalse(is_safetensors_compatible(filenames))
+
+    def test_transformers_is_compatible_sharded(self):
+        filenames = [
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/model-00001-of-00002.safetensors",
+            "text_encoder/model-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_transformers_is_compatible_variant_sharded(self):
+        filenames = [
+            "text_encoder/pytorch_model.bin",
+            "text_encoder/model.fp16-00001-of-00002.safetensors",
+            "text_encoder/model.fp16-00001-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_sharded(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model-00001-of-00002.safetensors",
+            "unet/diffusion_pytorch_model-00002-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_variant_sharded(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.bin",
+            "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
+            "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
+
+    def test_diffusers_is_compatible_only_variants(self):
+        filenames = [
+            "unet/diffusion_pytorch_model.fp16.safetensors",
+        ]
+        self.assertTrue(is_safetensors_compatible(filenames))
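One reason the new sharded and variant test cases pass is that the rewritten check only inspects each file's extension via `os.path.splitext`. A small standalone snippet (not part of the commit) showing that every such name resolves to the ".safetensors" extension:

import os

# The new check only looks at the extension of each per-component file, so plain,
# variant, and sharded safetensors names all count as a safetensors match.
for name in [
    "diffusion_pytorch_model.safetensors",
    "diffusion_pytorch_model.fp16.safetensors",
    "diffusion_pytorch_model-00001-of-00002.safetensors",
    "model.fp16-00001-of-00002.safetensors",
]:
    print(name, "->", os.path.splitext(name)[1])  # '.safetensors' in every case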
tests/pipelines/test_pipelines.py  (view file @ b2add10d)

@@ -551,37 +551,94 @@ class DownloadTests(unittest.TestCase):
         assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
         assert not any(f.endswith(other_format) for f in files)

-    def test_download_broken_variant(self):
-        for use_safetensors in [False, True]:
-            # text encoder is missing no variant and "no_ema" variant weights, so the following can't work
-            for variant in [None, "no_ema"]:
-                with self.assertRaises(OSError) as error_context:
-                    with tempfile.TemporaryDirectory() as tmpdirname:
-                        tmpdirname = StableDiffusionPipeline.from_pretrained(
-                            "hf-internal-testing/stable-diffusion-broken-variants",
-                            cache_dir=tmpdirname,
-                            variant=variant,
-                            use_safetensors=use_safetensors,
-                        )
-
-                assert "Error no file name" in str(error_context.exception)
-
-            # text encoder has fp16 variants so we can load it
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-broken-variants",
-                    use_safetensors=use_safetensors,
-                    cache_dir=tmpdirname,
-                    variant="fp16",
-                )
-                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
-                files = [item for sublist in all_root_files for item in sublist]
-
-                # None of the downloaded files should be a non-variant file even if we have some here:
-                # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
-                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+    def test_download_safetensors_only_variant_exists_for_model(self):
+        variant = None
+        use_safetensors = True
+
+        # text encoder is missing no variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)
+
+        # text encoder has fp16 variants so we can load it
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = StableDiffusionPipeline.download(
+                "hf-internal-testing/stable-diffusion-broken-variants",
+                use_safetensors=use_safetensors,
+                cache_dir=tmpdirname,
+                variant="fp16",
+            )
+            all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # None of the downloaded files should be a non-variant file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
+            assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+
+    def test_download_bin_only_variant_exists_for_model(self):
+        variant = None
+        use_safetensors = False
+
+        # text encoder is missing Non-variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)
+
+        # text encoder has fp16 variants so we can load it
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdirname = StableDiffusionPipeline.download(
+                "hf-internal-testing/stable-diffusion-broken-variants",
+                use_safetensors=use_safetensors,
+                cache_dir=tmpdirname,
+                variant="fp16",
+            )
+            all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # None of the downloaded files should be a non-variant file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
+            assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+
+    # only unet has "no_ema" variant
+    def test_download_safetensors_variant_does_not_exist_for_model(self):
+        variant = "no_ema"
+        use_safetensors = True
+
+        # text encoder is missing no_ema variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+
+            assert "Error no file name" in str(error_context.exception)
+
+    def test_download_bin_variant_does_not_exist_for_model(self):
+        variant = "no_ema"
+        use_safetensors = False
+
+        # text encoder is missing no_ema variant weights, so the following can't work
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with self.assertRaises(OSError) as error_context:
+                tmpdirname = StableDiffusionPipeline.from_pretrained(
+                    "hf-internal-testing/stable-diffusion-broken-variants",
+                    cache_dir=tmpdirname,
+                    variant=variant,
+                    use_safetensors=use_safetensors,
+                )
+            assert "Error no file name" in str(error_context.exception)

     def test_local_save_load_index(self):
         prompt = "hello"
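The new download tests all collect the files that actually landed in the cache directory with the same pair of list comprehensions. An equivalent helper, shown only for readability (the name `list_downloaded_files` is invented here; the tests inline the comprehensions):

import os


def list_downloaded_files(cache_dir):
    # Same pattern as the new tests: take the filename list from each os.walk() entry
    # and flatten it into a single list of file names.
    all_root_files = [t[-1] for t in os.walk(cache_dir)]
    return [item for sublist in all_root_files for item in sublist]


# Used in the tests roughly as:
#   files = list_downloaded_files(tmpdirname)
#   assert len(files) == 15, f"We should only download 15 files, not {len(files)}"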