Unverified Commit 214372aa authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix a regression in `is_safetensors_compatible` (#9234)

fix
parent 867e0c91
...@@ -89,7 +89,7 @@ for library in LOADABLE_CLASSES: ...@@ -89,7 +89,7 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, passed_components=None) -> bool: def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
""" """
Checking for safetensors compatibility: Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in - The model is safetensors compatible only if there is a safetensors file for each model component present in
...@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool: ...@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
extension is replaced with ".safetensors" extension is replaced with ".safetensors"
""" """
passed_components = passed_components or [] passed_components = passed_components or []
if folder_names is not None:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
# extract all components of the pipeline and their associated files # extract all components of the pipeline and their associated files
components = {} components = {}
......
...@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if ( if (
use_safetensors use_safetensors
and not allow_pickle and not allow_pickle
and not is_safetensors_compatible(model_filenames, passed_components=passed_components) and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
): ):
raise EnvironmentError( raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
) )
if from_flax: if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components): elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"] ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
......
...@@ -116,6 +116,30 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -116,6 +116,30 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
] ]
self.assertFalse(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))
def test_transformers_is_compatible_sharded(self): def test_transformers_is_compatible_sharded(self):
filenames = [ filenames = [
"text_encoder/pytorch_model.bin", "text_encoder/pytorch_model.bin",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment