Unverified Commit 09779cbb authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)

* [Bump version] 0.13

* Bump model up

* up
parent b0cc7c20
......@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image
if isinstance(prompt, str):
......
......@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image
# 1. Check inputs. Raise error if not correct
......
......@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image
# 1. Check inputs
......
......@@ -23,7 +23,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, deprecate, randn_tensor
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1
@register_to_config
......@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
......
......@@ -22,7 +22,6 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
......@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype
......@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
steps_offset: int = 0,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
......
......@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import BaseOutput, deprecate, randn_tensor
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1
@register_to_config
......@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
......@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample: torch.FloatTensor,
generator=None,
return_dict: bool = True,
**kwargs,
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
......@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor.
"""
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config)
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
......
......@@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
......@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype
......@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
clip_sample: bool = True,
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
......
......@@ -21,7 +21,6 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
......@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1
@register_to_config
......@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
......
......@@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
......@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype
......@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_type: str = "midpoint",
lower_order_final: bool = True,
dtype: jnp.dtype = jnp.float32,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
......
......@@ -19,7 +19,6 @@ import numpy as np
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
......@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_eps_slice = image_eps[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
def test_inference_predict_sample(self):
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(prediction_type="sample")
......
......@@ -26,7 +26,6 @@ from diffusers import (
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import CaptureLogger
......@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
with CaptureLogger(logger) as cap_logger:
deprecate("remove this case", "0.13.0", "remove")
ddpm_3 = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
beta_end=8,
)
assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.prediction_type == "sample"
assert ddpm.config.beta_end == 8
assert ddpm_2.config.beta_start == 88
assert ddpm_3.config.prediction_type == "sample"
# no warning should be thrown
assert cap_logger.out == ""
......
......@@ -45,7 +45,7 @@ from diffusers import (
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import deprecate, torch_device
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import CaptureLogger
......@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
def test_deprecated_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
sample = self.dummy_sample_deter
residual = 0.1 * self.dummy_sample_deter
time_step = 4
scheduler = scheduler_class(**scheduler_config)
scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
kwargs = {}
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
kwargs["generator"] = torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
kwargs = {}
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
kwargs["generator"] = torch.manual_seed(0)
output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
assert (output - output_eps).abs().sum() < 1e-5
def test_time_indices(self):
for t in [0, 500, 999]:
self.check_over_forward(time_step=t)
......
......@@ -18,7 +18,7 @@ import unittest
from typing import Dict, List, Tuple
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import deprecate, is_flax_available
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
......@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
def test_deprecated_predict_epsilon_to_prediction_type(self):
deprecate("remove this test", "0.13.0", "remove")
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
scheduler = scheduler_class.from_config(scheduler_config)
assert scheduler.prediction_type == "epsilon"
scheduler_config = self.get_scheduler_config(predict_epsilon=False)
scheduler = scheduler_class.from_config(scheduler_config)
assert scheduler.prediction_type == "sample"
@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment