Unverified commit c18941b0 authored by Patrick von Platen, committed by GitHub

[Better scheduler docs] Improve usage examples of schedulers (#890)



* [Better scheduler docs] Improve usage examples of schedulers

* finish

* fix warnings and add test

* finish

* more replacements

* adapt fast tests hf token

* correct more

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Integrate compatibility with euler
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent a1ea8c01
@@ -42,6 +42,8 @@ jobs:
        python utils/print_env.py
    - name: Run all fast tests on CPU
+     env:
+       HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
      run: |
        python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
@@ -91,6 +93,8 @@ jobs:
    - name: Run all fast tests on MPS
      shell: arch -arch arm64 bash {0}
+     env:
+       HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
      run: |
        ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
...
@@ -142,11 +142,7 @@ it before the pipeline and pass it to `from_pretrained`.
 ```python
 from diffusers import LMSDiscreteScheduler

-lms = LMSDiscreteScheduler(
-    beta_start=0.00085,
-    beta_end=0.012,
-    beta_schedule="scaled_linear"
-)
+lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
...
@@ -121,7 +121,7 @@ you could use it as follows:
 ```python
 >>> from diffusers import LMSDiscreteScheduler

->>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+>>> scheduler = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
 >>> generator = StableDiffusionPipeline.from_pretrained(
 ...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
...
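The documentation hunks above replace hand-typed beta parameters with `from_config`, which reads `scheduler_config.json` from the checkpoint's `scheduler` subfolder. A minimal end-to-end sketch of the resulting workflow (assuming a diffusers version from this PR's era, where schedulers load Hub configs via `from_config`; in later releases this loading path was renamed to `from_pretrained`):

```python
# Minimal sketch: swap the pipeline's default scheduler for LMS without
# hardcoding beta_start/beta_end/beta_schedule.
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

# Reads the beta schedule from the checkpoint's own scheduler_config.json.
scheduler = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=scheduler,  # overrides the scheduler bundled with the pipeline
)
image = pipe("an astronaut riding a horse").images[0]
```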
@@ -469,9 +469,7 @@ def main(args):
         eps=args.adam_epsilon,
     )

-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
...
@@ -372,11 +372,7 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )

-    # TODO (patil-suraj): load scheduler using args
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -609,9 +605,7 @@ def main():
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        ),
+        scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
         safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
         feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
     )
...
@@ -419,13 +419,7 @@ def main():
         eps=args.adam_epsilon,
     )

-    # TODO (patil-suraj): load scheduler using args
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule="scaled_linear",
-        num_train_timesteps=1000,
-    )
+    noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

     train_dataset = TextualInversionDataset(
         data_root=args.train_data_dir,
@@ -558,9 +552,7 @@ def main():
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        ),
+        scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
         safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
         feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
     )
...
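All of the training-script hunks above lean on the same mechanism: the checkpoint's config was saved by one scheduler class (e.g. `PNDMScheduler`), yet `DDPMScheduler.from_config` can consume it, because the compatibility filtering introduced below in `configuration_utils.py` drops keys the target `__init__` does not accept (such as `skip_prk_steps`). A hedged sketch of that round trip:

```python
# Sketch: one saved scheduler config, two scheduler classes. Assumes the
# compatibility filtering added by this PR; PNDM-only keys like
# `skip_prk_steps` are dropped when DDPMScheduler reads the same config.
from diffusers import DDPMScheduler, PNDMScheduler

repo = "CompVis/stable-diffusion-v1-4"
train_sched = DDPMScheduler.from_config(repo, subfolder="scheduler")  # adds noise during training
infer_sched = PNDMScheduler.from_config(repo, subfolder="scheduler")  # samples at eval time

# Both pick up the beta schedule stored in scheduler_config.json.
assert train_sched.config.beta_schedule == infer_sched.config.beta_schedule
```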
@@ -16,6 +16,7 @@
 """ ConfigMixin base class and utilities."""
 import dataclasses
 import functools
+import importlib
 import inspect
 import json
 import os
@@ -48,9 +49,13 @@ class ConfigMixin:
       [`~ConfigMixin.save_config`] (should be overridden by parent class).
     - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
       overridden by parent class).
+    - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
+      `from_config` can be used from a class different than the one used to save the config (should be overridden
+      by parent class).
     """

     config_name = None
     ignore_for_config = []
+    _compatible_classes = []

     def register_to_config(self, **kwargs):
         if self.config_name is None:
@@ -280,9 +285,14 @@ class ConfigMixin:

         return config_dict

+    @staticmethod
+    def _get_init_keys(cls):
+        return set(dict(inspect.signature(cls.__init__).parameters).keys())
+
     @classmethod
     def extract_init_dict(cls, config_dict, **kwargs):
-        expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+        # 1. Retrieve expected config attributes from the __init__ signature
+        expected_keys = cls._get_init_keys(cls)
         expected_keys.remove("self")
         # remove general kwargs if present in dict
         if "kwargs" in expected_keys:
@@ -292,9 +302,36 @@ class ConfigMixin:
             for arg in cls._flax_internal_args:
                 expected_keys.remove(arg)

+        # 2. Remove attributes that cannot be expected from expected config attributes
         # remove keys to be ignored
         if len(cls.ignore_for_config) > 0:
             expected_keys = expected_keys - set(cls.ignore_for_config)

+        # load diffusers library to import compatible and original schedulers
+        diffusers_library = importlib.import_module(__name__.split(".")[0])
+
+        # remove attributes from compatible classes that the original class cannot expect
+        compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
+        # filter out None, i.e. potentially undefined dummy classes
+        compatible_classes = [c for c in compatible_classes if c is not None]
+        expected_keys_comp_cls = set()
+        for c in compatible_classes:
+            expected_keys_c = cls._get_init_keys(c)
+            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
+        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
+        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
+
+        # remove attributes from the original class that the current class cannot expect
+        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
+        if orig_cls_name != cls.__name__:
+            orig_cls = getattr(diffusers_library, orig_cls_name)
+            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
+            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
+
+        # remove private attributes
+        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+
+        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
         init_dict = {}
         for key in expected_keys:
             if key in kwargs:
@@ -304,8 +341,7 @@ class ConfigMixin:
                 # use value from config dict
                 init_dict[key] = config_dict.pop(key)

-        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
-
+        # 4. Give a nice warning if unexpected values have been passed
         if len(config_dict) > 0:
             logger.warning(
                 f"The config attributes {config_dict} were passed to {cls.__name__}, "
@@ -313,14 +349,16 @@
                 f"{cls.config_name} configuration file."
             )

-        unused_kwargs = {**config_dict, **kwargs}
-
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed
         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
             logger.info(
                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )

+        # 6. Define unused keyword arguments
+        unused_kwargs = {**config_dict, **kwargs}
+
         return init_dict, unused_kwargs

     @classmethod
...
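To make the new steps in `extract_init_dict` concrete, here is a self-contained toy (the classes are hypothetical stand-ins, not real diffusers schedulers) that reproduces the set arithmetic: keys only a compatible class understands are filtered out, and private keys such as `_class_name` are stripped, leaving exactly what the target `__init__` accepts:

```python
import inspect


def get_init_keys(cls):
    # same trick as the new ConfigMixin._get_init_keys
    return set(inspect.signature(cls.__init__).parameters) - {"self"}


class PNDMLike:  # hypothetical stand-in for PNDMScheduler
    def __init__(self, beta_start=1e-4, beta_end=0.02, skip_prk_steps=False):
        pass


class DDPMLike:  # hypothetical stand-in for DDPMScheduler
    def __init__(self, beta_start=1e-4, beta_end=0.02, clip_sample=True):
        pass


# A config saved by PNDMLike, now being loaded by DDPMLike.
config = {"_class_name": "PNDMLike", "beta_start": 0.00085, "beta_end": 0.012, "skip_prk_steps": True}

target, compatibles = DDPMLike, [PNDMLike]
expected = get_init_keys(target)

# step 2a: drop keys that only compatible classes understand
comp_only = set().union(*(get_init_keys(c) for c in compatibles)) - expected
config = {k: v for k, v in config.items() if k not in comp_only}

# step 2b: drop private keys such as _class_name
config = {k: v for k, v in config.items() if not k.startswith("_")}

init_dict = {k: config[k] for k in expected if k in config}
print(init_dict)  # {'beta_start': 0.00085, 'beta_end': 0.012} (order may vary)
```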
@@ -272,7 +272,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
         >>> # Download pipeline, but overwrite scheduler
         >>> from diffusers import LMSDiscreteScheduler

-        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
         >>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
         ```
         """
...
@@ -360,7 +360,7 @@ class DiffusionPipeline(ConfigMixin):
         >>> # Download pipeline, but overwrite scheduler
         >>> from diffusers import LMSDiscreteScheduler

-        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
         >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
         ```
         """
@@ -602,7 +602,7 @@ class DiffusionPipeline(ConfigMixin):
         ...     StableDiffusionInpaintPipeline,
         ... )

-        >>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+        >>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
         >>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
         >>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
         ```
...
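Beyond the checkpoint swap, the second hunk above is a good reminder of what `components` buys you: the already-instantiated modules are passed by reference, so the derived pipelines share a single copy of the weights. A short sketch under the same-era API assumptions:

```python
# Sketch: share loaded weights across task-specific pipelines instead of
# calling from_pretrained three times.
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

txt2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
img2img = StableDiffusionImg2ImgPipeline(**txt2img.components)

assert img2img.unet is txt2img.unet  # the very same module object, not a copy
```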
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline, DDIMScheduler

-scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
@@ -91,11 +91,7 @@ image.save("astronaut_rides_horse.png")
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

-lms = LMSDiscreteScheduler(
-    beta_start=0.00085,
-    beta_end=0.012,
-    beta_schedule="scaled_linear"
-)
+lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

 pipe = StableDiffusionPipeline.from_pretrained(
     "runwayml/stable-diffusion-v1-5",
...
@@ -5,6 +5,7 @@ import numpy as np

 from transformers import CLIPFeatureExtractor, CLIPTokenizer

+from ...configuration_utils import FrozenDict
 from ...onnx_utils import OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -36,6 +37,34 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure"
+                " to update the config accordingly, as leaving `steps_offset` at its current value might lead to"
+                " incorrect results in future versions. If you have downloaded this checkpoint from the"
+                " Hugging Face Hub, it would be very nice if you could open a Pull request for the"
+                " `scheduler/scheduler_config.json` file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
...
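The same `clip_sample` patch is repeated verbatim in the six pipeline constructors below. The shape of the pattern: emit a deprecation via `deprecate`, then rewrite the frozen config in place through `_internal_dict`. A condensed sketch with a hypothetical `_patch_scheduler_config` helper (not part of diffusers; import paths assumed from this era of the codebase):

```python
# Hypothetical helper condensing the repeated in-place config patch.
# Configs are FrozenDicts, so the pipelines rebuild a plain dict and
# overwrite scheduler._internal_dict after emitting the deprecation.
from diffusers.configuration_utils import FrozenDict
from diffusers.utils import deprecate


def _patch_scheduler_config(scheduler, key, expected_value, deprecation_message):
    current = getattr(scheduler.config, key, None)
    if current is not None and current != expected_value:
        deprecate(f"{key} incorrectly set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config[key] = expected_value
        scheduler._internal_dict = FrozenDict(new_config)
```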
@@ -90,6 +90,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...
@@ -104,6 +104,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...
@@ -80,6 +80,19 @@ class StableDiffusionPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...
@@ -91,6 +91,19 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...
@@ -90,6 +90,19 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...
@@ -96,6 +96,19 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has `clip_sample` set to True."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as leaving `clip_sample` set to True might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
         if safety_checker is None:
             logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...
@@ -109,6 +109,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     """

+    _compatible_classes = [
+        "PNDMScheduler",
+        "DDPMScheduler",
+        "LMSDiscreteScheduler",
+        "EulerDiscreteScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
...
@@ -102,6 +102,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "PNDMScheduler",
+        "LMSDiscreteScheduler",
+        "EulerDiscreteScheduler",
+        "EulerAncestralDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
...
@@ -67,6 +67,14 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

+    _compatible_classes = [
+        "DDIMScheduler",
+        "DDPMScheduler",
+        "LMSDiscreteScheduler",
+        "PNDMScheduler",
+        "EulerDiscreteScheduler",
+    ]
+
     @register_to_config
     def __init__(
         self,
...
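With `_compatible_classes` declared on each scheduler (the lists hold class names as strings, resolved lazily through `importlib` so that optional or dummy classes don't break imports), a config written by any member of the family can bootstrap any other. The user-facing payoff, sketched under the same-era API assumptions as above:

```python
# Sketch: construct an Euler-family sampler straight from a config that was
# saved by PNDMScheduler; the compatibility lists above make the keys line up.
from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionPipeline

scheduler = EulerAncestralDiscreteScheduler.from_config(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", scheduler=scheduler
)
```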