Unverified commit b5c2050a, authored by kaixuanliu and committed by GitHub
Browse files

Fix bug when `variant` and `safetensor` file does not match (#11587)



* Apply style fixes

* init test
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* add the variant check when there are no component folders
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update related test cases
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update related unit test cases
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* Apply style fixes

---------
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 7ae546f8
...@@ -92,7 +92,7 @@ for library in LOADABLE_CLASSES: ...@@ -92,7 +92,7 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool: def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
""" """
Checking for safetensors compatibility: Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in - The model is safetensors compatible only if there is a safetensors file for each model component present in
...@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No ...@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors" extension is replaced with ".safetensors"
""" """
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# -00001-of-00002
transformers_index_format = r"\d{5}-of-\d{5}"
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
non_variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)
passed_components = passed_components or [] passed_components = passed_components or []
if folder_names: if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names} filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
...@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No ...@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
# If there are no component folders check the main directory for safetensors files # If there are no component folders check the main directory for safetensors files
if not components: if not components:
return any(".safetensors" in filename for filename in filenames) if variant is not None:
filtered_filenames = filter_with_regex(filenames, variant_file_re)
else:
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
return any(".safetensors" in filename for filename in filtered_filenames)
# iterate over all files of a component # iterate over all files of a component
# check if safetensor files exist for that component # check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists # if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items(): for component, component_filenames in components.items():
matches = [] matches = []
for component_filename in component_filenames: if variant is not None:
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
else:
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
for component_filename in filtered_component_filenames:
filename, extension = os.path.splitext(component_filename) filename, extension = os.path.splitext(component_filename)
match_exists = extension == ".safetensors" match_exists = extension == ".safetensors"
...@@ -159,6 +192,10 @@ def filter_model_files(filenames): ...@@ -159,6 +192,10 @@ def filter_model_files(filenames):
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)] return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
def filter_with_regex(filenames, pattern_re):
    """
    Keep only the filenames whose final path component matches `pattern_re`.

    Args:
        filenames: Iterable of "/"-separated file paths.
        pattern_re: Pattern object exposing `.match`; the match is attempted from the start of the basename.

    Returns:
        A set containing the full, unmodified paths whose basename matched.
    """
    selected = set()
    for path in filenames:
        # Only the basename takes part in matching, e.g.
        # "unet/model.fp16.safetensors" -> "model.fp16.safetensors".
        basename = path.rsplit("/", 1)[-1]
        if pattern_re.match(basename) is not None:
            selected.add(path)
    return selected
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]: def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [ weight_names = [
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -207,9 +244,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) - ...@@ -207,9 +244,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
# interested in the extension name # interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)} return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
# Group files by component # Group files by component
components = {} components = {}
for filename in filenames: for filename in filenames:
...@@ -997,7 +1031,7 @@ def _get_ignore_patterns( ...@@ -997,7 +1031,7 @@ def _get_ignore_patterns(
use_safetensors use_safetensors
and not allow_pickle and not allow_pickle
and not is_safetensors_compatible( and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
) )
): ):
raise EnvironmentError( raise EnvironmentError(
...@@ -1008,7 +1042,7 @@ def _get_ignore_patterns( ...@@ -1008,7 +1042,7 @@ def _get_ignore_patterns(
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible( elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
): ):
ignore_patterns = ["*.bin", "*.msgpack"] ignore_patterns = ["*.bin", "*.msgpack"]
......
...@@ -87,21 +87,24 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -87,21 +87,24 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors", "unet/diffusion_pytorch_model.fp16.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant(self): def test_diffusers_model_is_compatible_variant(self):
filenames = [ filenames = [
"unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors", "unet/diffusion_pytorch_model.fp16.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant_mixed(self): def test_diffusers_model_is_compatible_variant_mixed(self):
filenames = [ filenames = [
"unet/diffusion_pytorch_model.bin", "unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.safetensors", "unet/diffusion_pytorch_model.fp16.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_not_compatible_variant(self): def test_diffusers_model_is_not_compatible_variant(self):
filenames = [ filenames = [
...@@ -121,7 +124,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -121,7 +124,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"text_encoder/pytorch_model.fp16.bin", "text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors", "text_encoder/model.fp16.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_transformer_model_is_not_compatible_variant(self): def test_transformer_model_is_not_compatible_variant(self):
filenames = [ filenames = [
...@@ -145,7 +149,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -145,7 +149,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin", "unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors", "unet/diffusion_pytorch_model.fp16.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"})) self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))
def test_transformer_model_is_not_compatible_variant_extra_folder(self): def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [ filenames = [
...@@ -173,7 +178,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -173,7 +178,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"text_encoder/model.fp16-00001-of-00002.safetensors", "text_encoder/model.fp16-00001-of-00002.safetensors",
"text_encoder/model.fp16-00001-of-00002.safetensors", "text_encoder/model.fp16-00001-of-00002.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_sharded(self): def test_diffusers_is_compatible_sharded(self):
filenames = [ filenames = [
...@@ -189,13 +195,15 @@ class IsSafetensorsCompatibleTests(unittest.TestCase): ...@@ -189,13 +195,15 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors", "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_only_variants(self): def test_diffusers_is_compatible_only_variants(self):
filenames = [ filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors", "unet/diffusion_pytorch_model.fp16.safetensors",
] ]
self.assertTrue(is_safetensors_compatible(filenames)) self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_no_components(self): def test_diffusers_is_compatible_no_components(self):
filenames = [ filenames = [
......
...@@ -538,26 +538,38 @@ class DownloadTests(unittest.TestCase): ...@@ -538,26 +538,38 @@ class DownloadTests(unittest.TestCase):
variant = "no_ema" variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download( if use_safetensors:
"hf-internal-testing/stable-diffusion-all-variants", with self.assertRaises(OSError) as error_context:
cache_dir=tmpdirname, tmpdirname = StableDiffusionPipeline.download(
variant=variant, "hf-internal-testing/stable-diffusion-all-variants",
use_safetensors=use_safetensors, cache_dir=tmpdirname,
) variant=variant,
all_root_files = [t[-1] for t in os.walk(tmpdirname)] use_safetensors=use_safetensors,
files = [item for sublist in all_root_files for item in sublist] )
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
unet_files = os.listdir(os.path.join(tmpdirname, "unet")) else:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
# Some of the downloaded files should be a non-variant file, check: unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}" # Some of the downloaded files should be a non-variant file, check:
# only unet has "no_ema" variant # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1 # only unet has "no_ema" variant
# vae, safety_checker and text_encoder should have no variant assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3 assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
assert not any(f.endswith(other_format) for f in files) # vae, safety_checker and text_encoder should have no variant
assert (
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
)
assert not any(f.endswith(other_format) for f in files)
def test_download_variants_with_sharded_checkpoints(self): def test_download_variants_with_sharded_checkpoints(self):
# Here we test for downloading of "variant" files belonging to the `unet` and # Here we test for downloading of "variant" files belonging to the `unet` and
...@@ -588,20 +600,17 @@ class DownloadTests(unittest.TestCase): ...@@ -588,20 +600,17 @@ class DownloadTests(unittest.TestCase):
logger = logging.get_logger("diffusers.pipelines.pipeline_utils") logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant" deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
for is_local in [True, False]: with CaptureLogger(logger) as cap_logger:
with CaptureLogger(logger) as cap_logger: with tempfile.TemporaryDirectory() as tmpdirname:
with tempfile.TemporaryDirectory() as tmpdirname: local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
local_repo_id = repo_id
if is_local:
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
_ = DiffusionPipeline.from_pretrained( _ = DiffusionPipeline.from_pretrained(
local_repo_id, local_repo_id,
safety_checker=None, safety_checker=None,
variant="fp16", variant="fp16",
use_safetensors=True, use_safetensors=True,
) )
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs" assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
def test_download_safetensors_only_variant_exists_for_model(self): def test_download_safetensors_only_variant_exists_for_model(self):
variant = None variant = None
...@@ -616,7 +625,7 @@ class DownloadTests(unittest.TestCase): ...@@ -616,7 +625,7 @@ class DownloadTests(unittest.TestCase):
variant=variant, variant=variant,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
) )
assert "Error no file name" in str(error_context.exception) assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
# text encoder has fp16 variants so we can load it # text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -675,7 +684,7 @@ class DownloadTests(unittest.TestCase): ...@@ -675,7 +684,7 @@ class DownloadTests(unittest.TestCase):
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
) )
assert "Error no file name" in str(error_context.exception) assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
def test_download_bin_variant_does_not_exist_for_model(self): def test_download_bin_variant_does_not_exist_for_model(self):
variant = "no_ema" variant = "no_ema"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment