Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
856dad57
Unverified
Commit
856dad57
authored
Feb 28, 2023
by
Will Berman
Committed by
GitHub
Feb 28, 2023
Browse files
is_safetensors_compatible refactor (#2499)
* is_safetensors_compatible refactor * files list comma
parent
a75ac3fa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
176 additions
and
14 deletions
+176
-14
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+42
-14
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+134
-0
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
856dad57
...
...
@@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):
def
is_safetensors_compatible
(
filenames
,
variant
=
None
)
->
bool
:
pt_filenames
=
set
(
filename
for
filename
in
filenames
if
filename
.
endswith
(
".bin"
))
is_safetensors_compatible
=
any
(
file
.
endswith
(
".safetensors"
)
for
file
in
filenames
)
for
pt_filename
in
pt_filenames
:
_variant
=
f
".
{
variant
}
"
if
(
variant
is
not
None
and
variant
in
pt_filename
)
else
""
prefix
,
raw
=
os
.
path
.
split
(
pt_filename
)
if
raw
==
f
"pytorch_model
{
_variant
}
.bin"
:
# transformers specific
sf_filename
=
os
.
path
.
join
(
prefix
,
f
"model
{
_variant
}
.safetensors"
)
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
files to know which safetensors files are needed.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames
=
[]
sf_filenames
=
set
()
for
filename
in
filenames
:
_
,
extension
=
os
.
path
.
splitext
(
filename
)
if
extension
==
".bin"
:
pt_filenames
.
append
(
filename
)
elif
extension
==
".safetensors"
:
sf_filenames
.
add
(
filename
)
for
filename
in
pt_filenames
:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
path
,
filename
=
os
.
path
.
split
(
filename
)
filename
,
extension
=
os
.
path
.
splitext
(
filename
)
if
filename
==
"pytorch_model"
:
filename
=
"model"
elif
filename
==
f
"pytorch_model.
{
variant
}
"
:
filename
=
f
"model.
{
variant
}
"
else
:
sf_filename
=
pt_filename
[:
-
len
(
".bin"
)]
+
".safetensors"
if
is_safetensors_compatible
and
sf_filename
not
in
filenames
:
logger
.
warning
(
f
"
{
sf_filename
}
not found"
)
is_safetensors_compatible
=
False
return
is_safetensors_compatible
filename
=
filename
expected_sf_filename
=
os
.
path
.
join
(
path
,
filename
)
expected_sf_filename
=
f
"
{
expected_sf_filename
}
.safetensors"
if
expected_sf_filename
not
in
sf_filenames
:
logger
.
warning
(
f
"
{
expected_sf_filename
}
not found"
)
return
False
return
True
def
variant_compatible_siblings
(
info
,
variant
=
None
)
->
Union
[
List
[
os
.
PathLike
],
str
]:
...
...
tests/pipelines/test_pipeline_utils.py
0 → 100644
View file @
856dad57
import
unittest
from
diffusers.pipelines.pipeline_utils
import
is_safetensors_compatible
class
IsSafetensorsCompatibleTests
(
unittest
.
TestCase
):
def
test_all_is_compatible
(
self
):
filenames
=
[
"safety_checker/pytorch_model.bin"
,
"safety_checker/model.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
"vae/diffusion_pytorch_model.safetensors"
,
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_compatible
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_diffusers_model_is_not_compatible
(
self
):
filenames
=
[
"safety_checker/pytorch_model.bin"
,
"safety_checker/model.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
"vae/diffusion_pytorch_model.safetensors"
,
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
"unet/diffusion_pytorch_model.bin"
,
# Removed: 'unet/diffusion_pytorch_model.safetensors',
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_compatible
(
self
):
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
))
def
test_transformer_model_is_not_compatible
(
self
):
filenames
=
[
"safety_checker/pytorch_model.bin"
,
"safety_checker/model.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
"vae/diffusion_pytorch_model.safetensors"
,
"text_encoder/pytorch_model.bin"
,
# Removed: 'text_encoder/model.safetensors',
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_all_is_compatible_variant
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_compatible_variant
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
"unet/diffusion_pytorch_model.bin"
,
"unet/diffusion_pytorch_model.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_diffusers_model_is_not_compatible_variant
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
"unet/diffusion_pytorch_model.fp16.bin"
,
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant
(
self
):
filenames
=
[
"text_encoder/pytorch_model.fp16.bin"
,
"text_encoder/model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_compatible_variant_partial
(
self
):
# pass variant but use the non-variant filenames
filenames
=
[
"text_encoder/pytorch_model.bin"
,
"text_encoder/model.safetensors"
,
]
variant
=
"fp16"
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
def
test_transformer_model_is_not_compatible_variant
(
self
):
filenames
=
[
"safety_checker/pytorch_model.fp16.bin"
,
"safety_checker/model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.fp16.bin"
,
"vae/diffusion_pytorch_model.fp16.safetensors"
,
"text_encoder/pytorch_model.fp16.bin"
,
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin"
,
"unet/diffusion_pytorch_model.fp16.safetensors"
,
]
variant
=
"fp16"
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
variant
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment