Unverified Commit 826f4350 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix mixed variant downloading (#11611)

* update

* update
parent 4af76d0d
...@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No ...@@ -146,21 +146,27 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components[component].append(component_filename) components[component].append(component_filename)
# If there are no component folders check the main directory for safetensors files # If there are no component folders check the main directory for safetensors files
filtered_filenames = set()
if not components: if not components:
if variant is not None: if variant is not None:
filtered_filenames = filter_with_regex(filenames, variant_file_re) filtered_filenames = filter_with_regex(filenames, variant_file_re)
else:
# If no variant filenames exist check if non-variant files are available
if not filtered_filenames:
filtered_filenames = filter_with_regex(filenames, non_variant_file_re) filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
return any(".safetensors" in filename for filename in filtered_filenames) return any(".safetensors" in filename for filename in filtered_filenames)
# iterate over all files of a component # iterate over all files of a component
# check if safetensor files exist for that component # check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items(): for component, component_filenames in components.items():
matches = [] matches = []
filtered_component_filenames = set()
# if variant is provided check if the variant of the safetensors exists
if variant is not None: if variant is not None:
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re) filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
else:
# if variant safetensor files do not exist check for non-variants
if not filtered_component_filenames:
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re) filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
for component_filename in filtered_component_filenames: for component_filename in filtered_component_filenames:
filename, extension = os.path.splitext(component_filename) filename, extension = os.path.splitext(component_filename)
......
...@@ -217,6 +217,20 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -217,6 +217,20 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
] ]
self.assertFalse(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
def test_is_compatible_mixed_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_is_compatible_variant_and_non_safetensors(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))
class VariantCompatibleSiblingsTest(unittest.TestCase): class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self): def test_only_non_variants_downloaded(self):
......
...@@ -538,38 +538,26 @@ class DownloadTests(unittest.TestCase): ...@@ -538,38 +538,26 @@ class DownloadTests(unittest.TestCase):
variant = "no_ema" variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
if use_safetensors: tmpdirname = StableDiffusionPipeline.download(
with self.assertRaises(OSError) as error_context: "hf-internal-testing/stable-diffusion-all-variants",
tmpdirname = StableDiffusionPipeline.download( cache_dir=tmpdirname,
"hf-internal-testing/stable-diffusion-all-variants", variant=variant,
cache_dir=tmpdirname, use_safetensors=use_safetensors,
variant=variant, )
use_safetensors=use_safetensors, all_root_files = [t[-1] for t in os.walk(tmpdirname)]
) files = [item for sublist in all_root_files for item in sublist]
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
else:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(tmpdirname, "unet")) unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
# Some of the downloaded files should be a non-variant file, check: # Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}" assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant # only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant # vae, safety_checker and text_encoder should have no variant
assert ( assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 assert not any(f.endswith(other_format) for f in files)
)
assert not any(f.endswith(other_format) for f in files)
def test_download_variants_with_sharded_checkpoints(self): def test_download_variants_with_sharded_checkpoints(self):
# Here we test for downloading of "variant" files belonging to the `unet` and # Here we test for downloading of "variant" files belonging to the `unet` and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment