Unverified Commit 76c00c72 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

is_safetensors_compatible fix (#9741)

update
parent 0d9d98fe
...@@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No ...@@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components.setdefault(component, []) components.setdefault(component, [])
components[component].append(component_filename) components[component].append(component_filename)
# If there are no component folders check the main directory for safetensors files
if not components:
return any(".safetensors" in filename for filename in filenames)
# iterate over all files of a component # iterate over all files of a component
# check if safetensor files exist for that component # check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists # if variant is provided check if the variant of the safetensors exists
......
...@@ -197,6 +197,18 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -197,6 +197,18 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_no_components(self):
filenames = [
"diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_diffusers_is_compatible_no_components_only_variants(self):
filenames = [
"diffusion_pytorch_model.fp16.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))
class ProgressBarTests(unittest.TestCase): class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self): def get_dummy_components_image_generation(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment