Unverified Commit e4f6c379 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DiffusionPipeline] Deprecate not throwing error when loading non-existant variant (#4011)



* Deprecate variant nicely

* make style

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 98c9aac1
......@@ -1213,6 +1213,15 @@ class DiffusionPipeline(ConfigMixin):
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
if len(variant_filenames) == 0 and variant is not None:
deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)
# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import gc
import glob
import json
import os
import random
......@@ -56,6 +57,7 @@ from diffusers import (
UniPCMultistepScheduler,
logging,
)
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
......@@ -1361,6 +1363,29 @@ class PipelineFastTests(unittest.TestCase):
assert sd.config.safety_checker != (None, None)
assert sd.config.feature_extractor != (None, None)
def test_warning_no_variant_available(self):
variant = "fp16"
with self.assertWarns(FutureWarning) as warning_context:
cached_folder = StableDiffusionPipeline.download(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
)
assert "but no such modeling files are available" in str(warning_context.warning)
assert variant in str(warning_context.warning)
def get_all_filenames(directory):
filenames = glob.glob(directory + "/**", recursive=True)
filenames = [f for f in filenames if os.path.isfile(f)]
return filenames
filenames = get_all_filenames(str(cached_folder))
all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant)
# make sure that none of the model names are variant model names
assert len(variant_model_files) == 0
assert len(all_model_files) > 0
@slow
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment