Unverified Commit 09779cbb authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)

* [Bump version] 0.13

* Bump model up

* up
parent b0cc7c20
...@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
message = "Please use `image` instead of `init_image`." message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image image = init_image or image
if isinstance(prompt, str): if isinstance(prompt, str):
......
...@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
message = "Please use `image` instead of `init_image`." message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image image = init_image or image
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
......
...@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
message = "Please use `image` instead of `init_image`." message = "Please use `image` instead of `init_image`."
init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
image = init_image or image image = init_image or image
# 1. Check inputs # 1. Check inputs
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, deprecate, randn_tensor from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
...@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1 order = 1
@register_to_config @register_to_config
...@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear": elif beta_schedule == "linear":
......
...@@ -22,7 +22,6 @@ import flax ...@@ -22,7 +22,6 @@ import flax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import ( from .scheduling_utils_flax import (
CommonSchedulerState, CommonSchedulerState,
FlaxKarrasDiffusionSchedulers, FlaxKarrasDiffusionSchedulers,
...@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype dtype: jnp.dtype
...@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32,
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
self.dtype = dtype self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
......
...@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union ...@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, deprecate, randn_tensor from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
...@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1 order = 1
@register_to_config @register_to_config
...@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample: torch.FloatTensor, sample: torch.FloatTensor,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
**kwargs,
) -> Union[DDPMSchedulerOutput, Tuple]: ) -> Union[DDPMSchedulerOutput, Tuple]:
""" """
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
...@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor. returning a tuple, the first element is the sample tensor.
""" """
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config)
t = timestep t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
......
...@@ -22,7 +22,6 @@ import jax ...@@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import ( from .scheduling_utils_flax import (
CommonSchedulerState, CommonSchedulerState,
FlaxKarrasDiffusionSchedulers, FlaxKarrasDiffusionSchedulers,
...@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype dtype: jnp.dtype
...@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
clip_sample: bool = True, clip_sample: bool = True,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32,
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
self.dtype = dtype self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
......
...@@ -21,7 +21,6 @@ import numpy as np ...@@ -21,7 +21,6 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1 order = 1
@register_to_config @register_to_config
...@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear": elif beta_schedule == "linear":
......
...@@ -22,7 +22,6 @@ import jax ...@@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import ( from .scheduling_utils_flax import (
CommonSchedulerState, CommonSchedulerState,
FlaxKarrasDiffusionSchedulers, FlaxKarrasDiffusionSchedulers,
...@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype dtype: jnp.dtype
...@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32,
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
self.dtype = dtype self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
......
...@@ -19,7 +19,6 @@ import numpy as np ...@@ -19,7 +19,6 @@ import numpy as np
import torch import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
...@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase): ...@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_eps_slice = image_eps[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
def test_inference_predict_sample(self): def test_inference_predict_sample(self):
unet = self.dummy_uncond_unet unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(prediction_type="sample") scheduler = DDPMScheduler(prediction_type="sample")
......
...@@ -26,7 +26,6 @@ from diffusers import ( ...@@ -26,7 +26,6 @@ from diffusers import (
logging, logging,
) )
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import CaptureLogger from diffusers.utils.testing_utils import CaptureLogger
...@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase): ...@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
with CaptureLogger(logger) as cap_logger_2: with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88) ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
with CaptureLogger(logger) as cap_logger:
deprecate("remove this case", "0.13.0", "remove")
ddpm_3 = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
beta_end=8,
)
assert ddpm.__class__ == DDPMScheduler assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.prediction_type == "sample" assert ddpm.config.prediction_type == "sample"
assert ddpm.config.beta_end == 8 assert ddpm.config.beta_end == 8
assert ddpm_2.config.beta_start == 88 assert ddpm_2.config.beta_start == 88
assert ddpm_3.config.prediction_type == "sample"
# no warning should be thrown # no warning should be thrown
assert cap_logger.out == "" assert cap_logger.out == ""
......
...@@ -45,7 +45,7 @@ from diffusers import ( ...@@ -45,7 +45,7 @@ from diffusers import (
) )
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import deprecate, torch_device from diffusers.utils import torch_device
from diffusers.utils.testing_utils import CaptureLogger from diffusers.utils.testing_utils import CaptureLogger
...@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for prediction_type in ["epsilon", "sample", "v_prediction"]: for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
def test_deprecated_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
sample = self.dummy_sample_deter
residual = 0.1 * self.dummy_sample_deter
time_step = 4
scheduler = scheduler_class(**scheduler_config)
scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
kwargs = {}
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
kwargs["generator"] = torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
kwargs = {}
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
kwargs["generator"] = torch.manual_seed(0)
output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
assert (output - output_eps).abs().sum() < 1e-5
def test_time_indices(self): def test_time_indices(self):
for t in [0, 500, 999]: for t in [0, 500, 999]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import deprecate, is_flax_available from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax from diffusers.utils.testing_utils import require_flax
...@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
for prediction_type in ["epsilon", "sample", "v_prediction"]: for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.13.0", "remove")
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
def test_deprecated_predict_epsilon_to_prediction_type(self):
deprecate("remove this test", "0.13.0", "remove")
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(predict_epsilon=True)
scheduler = scheduler_class.from_config(scheduler_config)
assert scheduler.prediction_type == "epsilon"
scheduler_config = self.get_scheduler_config(predict_epsilon=False)
scheduler = scheduler_class.from_config(scheduler_config)
assert scheduler.prediction_type == "sample"
@require_flax @require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment