Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
b2add10d
Unverified
Commit
b2add10d
authored
Aug 19, 2024
by
Dhruv Nair
Committed by
GitHub
Aug 19, 2024
Browse files
Update `is_safetensors_compatible` check (#8991)
* update * update * update * update * update
parent
815d8822
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
151 additions
and
82 deletions
+151
-82
src/diffusers/pipelines/pipeline_loading_utils.py
src/diffusers/pipelines/pipeline_loading_utils.py
+21
-26
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+2
-6
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+46
-25
tests/pipelines/test_pipelines.py
tests/pipelines/test_pipelines.py
+82
-25
No files found.
src/diffusers/pipelines/pipeline_loading_utils.py
View file @
b2add10d
...
@@ -89,49 +89,44 @@ for library in LOADABLE_CLASSES:
...
@@ -89,49 +89,44 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES
.
update
(
LOADABLE_CLASSES
[
library
])
ALL_IMPORTABLE_CLASSES
.
update
(
LOADABLE_CLASSES
[
library
])
def
is_safetensors_compatible
(
filenames
,
variant
=
None
,
passed_components
=
None
)
->
bool
:
def
is_safetensors_compatible
(
filenames
,
passed_components
=
None
)
->
bool
:
"""
"""
Checking for safetensors compatibility:
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
- The model is safetensors compatible only if there is a safetensors file for each model component present in
files to know which safetensors files are needed.
filenames.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
Converting default pytorch serialized filenames to safetensors serialized filenames:
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
extension is replaced with ".safetensors"
"""
"""
pt_filenames
=
[]
sf_filenames
=
set
()
passed_components
=
passed_components
or
[]
passed_components
=
passed_components
or
[]
# extract all components of the pipeline and their associated files
components
=
{}
for
filename
in
filenames
:
for
filename
in
filenames
:
_
,
extension
=
os
.
path
.
splitext
(
filename
)
if
not
len
(
filename
.
split
(
"/"
))
==
2
:
continue
if
len
(
filename
.
split
(
"/"
))
==
2
and
filename
.
split
(
"/"
)[
0
]
in
passed_components
:
component
,
component_filename
=
filename
.
split
(
"/"
)
if
component
in
passed_components
:
continue
continue
if
extension
==
".bin"
:
components
.
setdefault
(
component
,
[])
pt_filenames
.
append
(
os
.
path
.
normpath
(
filename
))
components
[
component
].
append
(
component_filename
)
elif
extension
==
".safetensors"
:
sf_filenames
.
add
(
os
.
path
.
normpath
(
filename
))
for
filename
in
pt_filenames
:
# iterate over all files of a component
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
# check if safetensor files exist for that component
path
,
filename
=
os
.
path
.
split
(
filename
)
# if variant is provided check if the variant of the safetensors exists
filename
,
extension
=
os
.
path
.
splitext
(
filename
)
for
component
,
component_filenames
in
components
.
items
():
matches
=
[]
for
component_filename
in
component_filenames
:
filename
,
extension
=
os
.
path
.
splitext
(
component_filename
)
if
filename
.
startswith
(
"pytorch_model"
):
match_exists
=
extension
==
".safetensors"
filename
=
filename
.
replace
(
"pytorch_model"
,
"model"
)
matches
.
append
(
match_exists
)
else
:
filename
=
filename
expected_sf_filename
=
os
.
path
.
normpath
(
os
.
path
.
join
(
path
,
filename
))
if
not
any
(
matches
):
expected_sf_filename
=
f
"
{
expected_sf_filename
}
.safetensors"
if
expected_sf_filename
not
in
sf_filenames
:
logger
.
warning
(
f
"
{
expected_sf_filename
}
not found"
)
return
False
return
False
return
True
return
True
...
...
src/diffusers/pipelines/pipeline_utils.py
View file @
b2add10d
...
@@ -1416,18 +1416,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
...
@@ -1416,18 +1416,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if
(
if
(
use_safetensors
use_safetensors
and
not
allow_pickle
and
not
allow_pickle
and
not
is_safetensors_compatible
(
and
not
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
)
model_filenames
,
variant
=
variant
,
passed_components
=
passed_components
)
):
):
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"Could not find the necessary `safetensors` weights in
{
model_filenames
}
(variant=
{
variant
}
)"
f
"Could not find the necessary `safetensors` weights in
{
model_filenames
}
(variant=
{
variant
}
)"
)
)
if
from_flax
:
if
from_flax
:
ignore_patterns
=
[
"*.bin"
,
"*.safetensors"
,
"*.onnx"
,
"*.pb"
]
ignore_patterns
=
[
"*.bin"
,
"*.safetensors"
,
"*.onnx"
,
"*.pb"
]
elif
use_safetensors
and
is_safetensors_compatible
(
elif
use_safetensors
and
is_safetensors_compatible
(
model_filenames
,
passed_components
=
passed_components
):
model_filenames
,
variant
=
variant
,
passed_components
=
passed_components
):
ignore_patterns
=
[
"*.bin"
,
"*.msgpack"
]
ignore_patterns
=
[
"*.bin"
,
"*.msgpack"
]
use_onnx
=
use_onnx
if
use_onnx
is
not
None
else
pipeline_class
.
_is_onnx
use_onnx
=
use_onnx
if
use_onnx
is
not
None
else
pipeline_class
.
_is_onnx
...
...
tests/pipelines/test_pipeline_utils.py
View file @
b2add10d
...
@@ -68,25 +68,21 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
...
@@ -68,25 +68,21 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_compatible_variant
(
self
):
def
test_diffusers_model_is_compatible_variant
(
self
):
filenames
=
[
filenames
=
[
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_compatible_variant_partial
(
self
):
def
test_diffusers_model_is_compatible_variant_mixed
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_not_compatible_variant
(
self
):
def
test_diffusers_model_is_not_compatible_variant
(
self
):
filenames
=
[
filenames
=
[
...
@@ -99,25 +95,14 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
...
@@ -99,25 +95,14 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant
(
self
):
def
test_transformer_model_is_compatible_variant
(
self
):
filenames
=
[
filenames
=
[
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
"text_encoder/model.fp16.safetensors"
,
]
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_not_compatible_variant
(
self
):
def
test_transformer_model_is_not_compatible_variant
(
self
):
filenames
=
[
filenames
=
[
...
@@ -126,9 +111,45 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
...
@@ -126,9 +111,45 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/pytorch_model.fp16.bin"
,
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformers_is_compatible_sharded
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model-00001-of-00002.safetensors"
,
"text_encoder/model-00002-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_transformers_is_compatible_variant_sharded
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.fp16-00001-of-00002.safetensors"
,
"text_encoder/model.fp16-00001-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_is_compatible_sharded
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model-00001-of-00002.safetensors"
,
"unet/diffusion_pytorch_model-00002-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_is_compatible_variant_sharded
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors"
,
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_is_compatible_only_variants
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
tests/pipelines/test_pipelines.py
View file @
b2add10d
...
@@ -551,37 +551,94 @@ class DownloadTests(unittest.TestCase):
...
@@ -551,37 +551,94 @@ class DownloadTests(unittest.TestCase):
assert
sum
(
f
.
endswith
(
this_format
)
and
not
f
.
endswith
(
f
"
{
variant
}{
this_format
}
"
)
for
f
in
files
)
==
3
assert
sum
(
f
.
endswith
(
this_format
)
and
not
f
.
endswith
(
f
"
{
variant
}{
this_format
}
"
)
for
f
in
files
)
==
3
assert
not
any
(
f
.
endswith
(
other_format
)
for
f
in
files
)
assert
not
any
(
f
.
endswith
(
other_format
)
for
f
in
files
)
def
test_download_broken_variant
(
self
):
def
test_download_safetensors_only_variant_exists_for_model
(
self
):
for
use_safetensors
in
[
False
,
True
]:
variant
=
None
# text encoder is missing both the non-variant and the "no_ema" variant weights, so the following can't work
use_safetensors
=
True
for
variant
in
[
None
,
"no_ema"
]:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
# text encoder is missing the non-variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
with
self
.
assertRaises
(
OSError
)
as
error_context
:
"hf-internal-testing/stable-diffusion-broken-variants"
,
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
# text encoder has fp16 variants so we can load it
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
# text encoder has fp16 variants so we can load it
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
use_safetensors
=
use_safetensors
,
cache_dir
=
tmpdirname
,
variant
=
"fp16"
,
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
tmpdirname
)]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert
len
(
files
)
==
15
,
f
"We should only download 15 files, not
{
len
(
files
)
}
"
def
test_download_bin_only_variant_exists_for_model
(
self
):
variant
=
None
use_safetensors
=
False
# text encoder is missing the non-variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
cache_dir
=
tmpdirname
,
variant
=
"fp16"
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
tmpdirname
)]
# text encoder has fp16 variants so we can load it
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
use_safetensors
=
use_safetensors
,
cache_dir
=
tmpdirname
,
variant
=
"fp16"
,
)
all_root_files
=
[
t
[
-
1
]
for
t
in
os
.
walk
(
tmpdirname
)]
files
=
[
item
for
sublist
in
all_root_files
for
item
in
sublist
]
# None of the downloaded files should be a non-variant file even if we have some here:
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert
len
(
files
)
==
15
,
f
"We should only download 15 files, not
{
len
(
files
)
}
"
# None of the downloaded files should be a non-variant file even if we have some here:
def
test_download_safetensors_variant_does_not_exist_for_model
(
self
):
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
variant
=
"no_ema"
assert
len
(
files
)
==
15
,
f
"We should only download 15 files, not
{
len
(
files
)
}
"
use_safetensors
=
True
# only unet has "no_ema" variant
# text encoder is missing no_ema variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
def
test_download_bin_variant_does_not_exist_for_model
(
self
):
variant
=
"no_ema"
use_safetensors
=
False
# text encoder is missing no_ema variant weights, so the following can't work
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/stable-diffusion-broken-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Error no file name"
in
str
(
error_context
.
exception
)
def
test_local_save_load_index
(
self
):
def
test_local_save_load_index
(
self
):
prompt
=
"hello"
prompt
=
"hello"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment