Unverified Commit 856dad57 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

is_safetensors_compatible refactor (#2499)

* is_safetensors_compatible refactor

* files list comma
parent a75ac3fa
...@@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput): ...@@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):
def is_safetensors_compatible(filenames, variant=None) -> bool: def is_safetensors_compatible(filenames, variant=None) -> bool:
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) """
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
for pt_filename in pt_filenames: files to know which safetensors files are needed.
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else "" - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
prefix, raw = os.path.split(pt_filename)
if raw == f"pytorch_model{_variant}.bin": Converting default pytorch serialized filenames to safetensors serialized filenames:
# transformers specific - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors") - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames = []
sf_filenames = set()
for filename in filenames:
_, extension = os.path.splitext(filename)
if extension == ".bin":
pt_filenames.append(filename)
elif extension == ".safetensors":
sf_filenames.add(filename)
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)
if filename == "pytorch_model":
filename = "model"
elif filename == f"pytorch_model.{variant}":
filename = f"model.{variant}"
else: else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors" filename = filename
if is_safetensors_compatible and sf_filename not in filenames:
logger.warning(f"{sf_filename} not found") expected_sf_filename = os.path.join(path, filename)
is_safetensors_compatible = False expected_sf_filename = f"{expected_sf_filename}.safetensors"
return is_safetensors_compatible
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
return True
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]: def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
......
import unittest
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
class IsSafetensorsCompatibleTests(unittest.TestCase):
def test_all_is_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_compatible(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_diffusers_model_is_not_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
"unet/diffusion_pytorch_model.bin",
# Removed: 'unet/diffusion_pytorch_model.safetensors',
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible(self):
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
def test_transformer_model_is_not_compatible(self):
filenames = [
"safety_checker/pytorch_model.bin",
"safety_checker/model.safetensors",
"vae/diffusion_pytorch_model.bin",
"vae/diffusion_pytorch_model.safetensors",
"text_encoder/pytorch_model.bin",
# Removed: 'text_encoder/model.safetensors',
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_all_is_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
def test_diffusers_model_is_compatible_variant_partial(self):
# pass variant but use the non-variant filenames
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
"unet/diffusion_pytorch_model.fp16.bin",
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
]
variant = "fp16"
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
def test_transformer_model_is_compatible_variant(self):
filenames = [
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
def test_transformer_model_is_compatible_variant_partial(self):
# pass variant but use the non-variant filenames
filenames = [
"text_encoder/pytorch_model.bin",
"text_encoder/model.safetensors",
]
variant = "fp16"
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
def test_transformer_model_is_not_compatible_variant(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
# 'text_encoder/model.fp16.safetensors',
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
variant = "fp16"
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment