Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
826f4350
Unverified
Commit
826f4350
authored
May 26, 2025
by
Dhruv Nair
Committed by
GitHub
May 26, 2025
Browse files
Fix mixed variant downloading (#11611)
* update * update
parent
4af76d0d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
34 deletions
+42
-34
src/diffusers/pipelines/pipeline_loading_utils.py
src/diffusers/pipelines/pipeline_loading_utils.py
+9
-3
tests/pipelines/test_pipeline_utils.py
tests/pipelines/test_pipeline_utils.py
+14
-0
tests/pipelines/test_pipelines.py
tests/pipelines/test_pipelines.py
+19
-31
No files found.
src/diffusers/pipelines/pipeline_loading_utils.py
View file @
826f4350
...
...
@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components
[
component
].
append
(
component_filename
)
# If there are no component folders check the main directory for safetensors files
filtered_filenames
=
set
()
if
not
components
:
if
variant
is
not
None
:
filtered_filenames
=
filter_with_regex
(
filenames
,
variant_file_re
)
else
:
# If no variant filenames exist check if non-variant files are available
if
not
filtered_filenames
:
filtered_filenames
=
filter_with_regex
(
filenames
,
non_variant_file_re
)
return
any
(
".safetensors"
in
filename
for
filename
in
filtered_filenames
)
# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for
component
,
component_filenames
in
components
.
items
():
matches
=
[]
filtered_component_filenames
=
set
()
# if variant is provided check if the variant of the safetensors exists
if
variant
is
not
None
:
filtered_component_filenames
=
filter_with_regex
(
component_filenames
,
variant_file_re
)
else
:
# if variant safetensor files do not exist check for non-variants
if
not
filtered_component_filenames
:
filtered_component_filenames
=
filter_with_regex
(
component_filenames
,
non_variant_file_re
)
for
component_filename
in
filtered_component_filenames
:
filename
,
extension
=
os
.
path
.
splitext
(
component_filename
)
...
...
tests/pipelines/test_pipeline_utils.py
View file @
826f4350
...
...
@@ -217,6 +217,20 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
))
def
test_is_compatible_mixed_variants
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.safetensors"
,
]
self
.
assertTrue
(
is_safetensors_compatible
(
filenames
,
variant
=
"fp16"
))
def
test_is_compatible_variant_and_non_safetensors
(
self
):
filenames
=
[
"unet/diffusion_pytorch_model.fp16.safetensors"
,
"vae/diffusion_pytorch_model.bin"
,
]
self
.
assertFalse
(
is_safetensors_compatible
(
filenames
,
variant
=
"fp16"
))
class
VariantCompatibleSiblingsTest
(
unittest
.
TestCase
):
def
test_only_non_variants_downloaded
(
self
):
...
...
tests/pipelines/test_pipelines.py
View file @
826f4350
...
...
@@ -538,16 +538,6 @@ class DownloadTests(unittest.TestCase):
variant
=
"no_ema"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
if
use_safetensors
:
with
self
.
assertRaises
(
OSError
)
as
error_context
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-all-variants"
,
cache_dir
=
tmpdirname
,
variant
=
variant
,
use_safetensors
=
use_safetensors
,
)
assert
"Could not find the necessary `safetensors` weights"
in
str
(
error_context
.
exception
)
else
:
tmpdirname
=
StableDiffusionPipeline
.
download
(
"hf-internal-testing/stable-diffusion-all-variants"
,
cache_dir
=
tmpdirname
,
...
...
@@ -566,9 +556,7 @@ class DownloadTests(unittest.TestCase):
assert
f
"diffusion_pytorch_model.
{
variant
}{
this_format
}
"
in
unet_files
assert
len
([
f
for
f
in
files
if
f
.
endswith
(
f
"
{
variant
}{
this_format
}
"
)])
==
1
# vae, safety_checker and text_encoder should have no variant
assert
(
sum
(
f
.
endswith
(
this_format
)
and
not
f
.
endswith
(
f
"
{
variant
}{
this_format
}
"
)
for
f
in
files
)
==
3
)
assert
sum
(
f
.
endswith
(
this_format
)
and
not
f
.
endswith
(
f
"
{
variant
}{
this_format
}
"
)
for
f
in
files
)
==
3
assert
not
any
(
f
.
endswith
(
other_format
)
for
f
in
files
)
def
test_download_variants_with_sharded_checkpoints
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment