"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "e4879676d187bae8452aae7aae3e54f3de1ad8e3"
Unverified Commit 214372aa authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix a regression in `is_safetensors_compatible` (#9234)

fix
parent 867e0c91
......@@ -89,7 +89,7 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, passed_components=None) -> bool:
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
......@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names is not None:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
# extract all components of the pipeline and their associated files
components = {}
......
......@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
......
......@@ -116,6 +116,30 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))
def test_transformers_is_compatible_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment