Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
b2add10d
Unverified
Commit
b2add10d
authored
Aug 19, 2024
by
Dhruv Nair
Committed by
GitHub
Aug 19, 2024
Browse files
Update `is_safetensors_compatible` check (#8991)
* update * update * update * update * update
parent
815d8822
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
151 additions
and
82 deletions
+151
-82
src/diffusers/pipelines/pipeline_loading_utils.py
src/diffusers/pipelines/pipeline_loading_utils.py
+21
-26
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+2
-6
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+46
-25
tests/pipelines/test_pipelines.py
tests/pipelines/test_pipelines.py
+82
-25
No files found.
src/diffusers/pipelines/pipeline_loading_utils.py
View file @
b2add10d
...
...
@@ -89,49 +89,44 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES
.
update
(
LOADABLE_CLASSES
[
library
])
def
is_safetensors_compatible
(
filenames
,
variant
=
None
,
passed_components
=
None
)
->
bool
:
def
is_safetensors_compatible
(
filenames
,
passed_components
=
None
)
->
bool
:
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
files to know which safetensors files are needed.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
- The model is safetensors compatible only if there is a safetensors file for each model component present in
filenames.
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames
=
[]
sf_filenames
=
set
()
passed_components
=
passed_components
or
[]
# extract all components of the pipeline and their associated files
components
=
{}
for
filename
in
filenames
:
_
,
extension
=
os
.
path
.
splitext
(
filename
)
if
not
len
(
filename
.
split
(
"/"
))
==
2
:
continue
if
len
(
filename
.
split
(
"/"
))
==
2
and
filename
.
split
(
"/"
)[
0
]
in
passed_components
:
component
,
component_filename
=
filename
.
split
(
"/"
)
if
component
in
passed_components
:
continue
if
extension
==
".bin"
:
pt_filenames
.
append
(
os
.
path
.
normpath
(
filename
))
elif
extension
==
".safetensors"
:
sf_filenames
.
add
(
os
.
path
.
normpath
(
filename
))
components
.
setdefault
(
component
,
[])
components
[
component
].
append
(
component_filename
)
for
filename
in
pt_filenames
:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
path
,
filename
=
os
.
path
.
split
(
filename
)
filename
,
extension
=
os
.
path
.
splitext
(
filename
)
# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for
component
,
component_filenames
in
components
.
items
():
matches
=
[]
for
component_filename
in
component_filenames
:
filename
,
extension
=
os
.
path
.
splitext
(
component_filename
)
if
filename
.
startswith
(
"pytorch_model"
):
filename
=
filename
.
replace
(
"pytorch_model"
,
"model"
)
else
:
filename
=
filename
match_exists
=
extension
==
".safetensors"
matches
.
append
(
match_exists
)
expected_sf_filename
=
os
.
path
.
normpath
(
os
.
path
.
join
(
path
,
filename
))
expected_sf_filename
=
f
"
{
expected_sf_filename
}
.safetensors"
if
expected_sf_filename
not
in
sf_filenames
:
logger
.
warning
(
f
"
{
expected_sf_filename
}
not found"
)
if
not
any
(
matches
):
return
False
return
True
...
...
src/diffusers/pipelines/pipeline_utils.py
View file @
b2add10d
...
...
@@ -1416,18 +1416,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if
(
use_safetensors
and
not
allow_pickle
and
not
is_safetensors_compatible
(
model_filenames
,
variant
=
variant
,
passed_components
=
passed_components
)
and
not
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
)
):
raise
EnvironmentError
(
f
"Could not find the necessary `safetensors` weights in
{
model_filenames
}
(variant=
{
variant
}
)"
)
if
from_flax
:
ignore_patterns
=
[
"*.bin"
,
"*.safetensors"
,
"*.onnx"
,
"*.pb"
]
elif
use_safetensors
and
is_safetensors_compatible
(
model_filenames
,
variant
=
variant
,
passed_components
=
passed_components
):
elif
use_safetensors
and
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
):
ignore_patterns
=
[
"*.bin"
,
"*.msgpack"
]
use_onnx
=
use_onnx
if
use_onnx
is
not
None
else
pipeline_class
.
_is_onnx
...
...
tests/pipelines/test_pipeline_utils.py
View file @
b2add10d
...
...
@@ -68,25 +68,21 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_compatible_variant
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
def
test_diffusers_model_is_compatible_variant_mixed
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
"unet/diffusion_pytorch_model.
fp16.
safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_not_compatible_variant
(
self
):
filenames
=
[
...
...
@@ -99,25 +95,14 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin"
,
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_compatible_variant
(
self
):
filenames
=
[
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_not_compatible_variant
(
self
):
filenames
=
[
...
...
@@ -126,9 +111,45 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_transformers_is_compatible_sharded
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model-00001-of-00002.safetensors"
,
"text_encoder/model-00002-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_transformers_is_compatible_variant_sharded
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.fp16-00001-of-00002.safetensors"
,
"text_encoder/model.fp16-00001-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_is_compatible_sharded
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model-00001-of-00002.safetensors"
,
"unet/diffusion_pytorch_model-00002-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_is_compatible_variant_sharded
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors"
,
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_is_compatible_only_variants
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
tests/pipelines/test_pipelines.py
View file @
b2add10d
...
...
@@ -551,37 +551,94 @@ class DownloadTests(unittest.TestCase):
assert
sum
(
f
.
endswith
(
this_format
)
and
not
f
.
endswith
(
f
"
{
variant
}{
this_format
}
"
)
for
f
in
files
)
==
3
assert
not
any
(
f
.
endswith
(
other_format
)
for
f
in
files
)
def
test_download_broken_variant
(
self
):
for
use_safetensors
in
[
False
,
True
]:
# text encoder is missing no variant and "no_ema" variant weights, so the following can't work
for
variant
in
[
None
,
"no_ema"
]:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
# text encoder has fp16 variants so we can load it
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
def
test_download_safetensors_only_variant_exists_for_model
(
self
):
variant
=
None
use_safetensors
=
True
# text encoder is missing no variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
# text encoder has fp16 variants so we can load it
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
use_safetensors
=
use_safetensors
,
cache_dir
=
tmpdirname
,
variant
=
"fp16"
,
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
tmpdirname
)]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert
len
(
files
)
==
15
,
f
"We should only download 15 files, not
{
len
(
files
)
}
"
def
test_download_bin_only_variant_exists_for_model
(
self
):
variant
=
None
use_safetensors
=
False
# text encoder is missing Non-variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
"fp16"
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
tmpdirname
)]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# text encoder has fp16 variants so we can load it
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
use_safetensors
=
use_safetensors
,
cache_dir
=
tmpdirname
,
variant
=
"fp16"
,
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
tmpdirname
)]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert
len
(
files
)
==
15
,
f
"We should only download 15 files, not
{
len
(
files
)
}
"
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert
len
(
files
)
==
15
,
f
"We should only download 15 files, not
{
len
(
files
)
}
"
# only unet has "no_ema" variant
def
test_download_safetensors_variant_does_not_exist_for_model
(
self
):
variant
=
"no_ema"
use_safetensors
=
True
# text encoder is missing no_ema variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
def
test_download_bin_variant_does_not_exist_for_model
(
self
):
variant
=
"no_ema"
use_safetensors
=
False
# text encoder is missing no_ema variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
def
test_local_save_load_index
(
self
):
prompt
=
"hello"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment