Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
856dad57
Unverified
Commit
856dad57
authored
Feb 28, 2023
by
Will Berman
Committed by
GitHub
Feb 28, 2023
Browse files
is_safetensors_compatible refactor (#2499)
* is_safetensors_compatible refactor * files list comma
parent
a75ac3fa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
176 additions
and
14 deletions
+176
-14
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+42
-14
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+134
-0
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
856dad57
...
...
@@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):
def
is_safetensors_compatible
(
filenames
,
variant
=
None
)
->
bool
:
pt_filenames
=
set
(
filename
for
filename
in
filenames
if
filename
.
endswith
(
".bin"
))
is_safetensors_compatible
=
any
(
file
.
endswith
(
".safetensors"
)
for
file
in
filenames
)
for
pt_filename
in
pt_filenames
:
_variant
=
f
".
{
variant
}
"
if
(
variant
is
not
None
and
variant
in
pt_filename
)
else
""
prefix
,
raw
=
os
.
path
.
split
(
pt_filename
)
if
raw
==
f
"pytorch_model
{
_variant
}
.bin"
:
# transformers specific
sf_filename
=
os
.
path
.
join
(
prefix
,
f
"model
{
_variant
}
.safetensors"
)
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
files to know which safetensors files are needed.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames
=
[]
sf_filenames
=
set
()
for
filename
in
filenames
:
_
,
extension
=
os
.
path
.
splitext
(
filename
)
if
extension
==
".bin"
:
pt_filenames
.
append
(
filename
)
elif
extension
==
".safetensors"
:
sf_filenames
.
add
(
filename
)
for
filename
in
pt_filenames
:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
path
,
filename
=
os
.
path
.
split
(
filename
)
filename
,
extension
=
os
.
path
.
splitext
(
filename
)
if
filename
==
"pytorch_model"
:
filename
=
"model"
elif
filename
==
f
"pytorch_model.
{
variant
}
"
:
filename
=
f
"model.
{
variant
}
"
else
:
sf_filename
=
pt_filename
[:
-
len
(
".bin"
)]
+
".safetensors"
if
is_safetensors_compatible
and
sf_filename
not
in
filenames
:
logger
.
warning
(
f
"
{
sf_filename
}
not found"
)
is_safetensors_compatible
=
False
return
is_safetensors_compatible
filename
=
filename
expected_sf_filename
=
os
.
path
.
join
(
path
,
filename
)
expected_sf_filename
=
f
"
{
expected_sf_filename
}
.safetensors"
if
expected_sf_filename
not
in
sf_filenames
:
logger
.
warning
(
f
"
{
expected_sf_filename
}
not found"
)
return
False
return
True
def
variant_compatible_siblings
(
info
,
variant
=
None
)
->
Union
[
List
[
os
.
PathLike
],
str
]:
...
...
tests/pipelines/test_pipeline_utils.py
0 → 100644
View file @
856dad57
import
unittest
from
diffusers.pipelines.pipeline_utils
import
is_safetensors_compatible
class
IsSafetensorsCompatibleTests
(
unittest
.
TestCase
):
def
test_all_is_compatible
(
self
):
filenames
=
[
"safety_checker/pytorch_model.bin"
,
"safety_checker/model.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
"vae/diffusion_pytorch_model.safetensors"
,
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_compatible
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_not_compatible
(
self
):
filenames
=
[
"safety_checker/pytorch_model.bin"
,
"safety_checker/model.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
"vae/diffusion_pytorch_model.safetensors"
,
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
"unet/diffusion_pytorch_model.bin"
,
# Removed: 'unet/diffusion_pytorch_model.safetensors',
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_compatible
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_not_compatible
(
self
):
filenames
=
[
"safety_checker/pytorch_model.bin"
,
"safety_checker/model.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
"vae/diffusion_pytorch_model.safetensors"
,
"text_encoder/pytorch_model.bin"
,
# Removed: 'text_encoder/model.safetensors',
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_all_is_compatible_variant
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_compatible_variant
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_not_compatible_variant
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant
(
self
):
filenames
=
[
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_not_compatible_variant
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment