Unverified commit d52388f4, authored by Pedro Cuenca, committed by GitHub
Browse files

Deprecate `predict_epsilon` (#1393)



* Adapt ddpm, dpmsolver to prediction_type.

* Deprecate predict_epsilon in __init__.

* Bring FlaxDDIMScheduler up to date with DDIMScheduler.

* Set prediction_type as an ivar for consistency.

* Convert pipeline_ddpm

* Adapt tests.

* Adapt unconditional training script.

* Adapt BitDiffusion example.

* Add missing kwargs in dpmsolver_multistep

* Ugly workaround to accept deprecated predict_epsilon when loading
schedulers using from_pretrained.

* make style

* Remove import no longer in use.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Use config.prediction_type everywhere

* Add a couple of Flax prediction type tests.

* make style

* fix register deprecated arg
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent babfb8a0
...@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step( ...@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.FloatTensor,
predict_epsilon=True, prediction_type="epsilon",
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]: ) -> Union[DDPMSchedulerOutput, Tuple]:
...@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step( ...@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
predict_epsilon (`bool`): prediction_type (`str`, default `epsilon`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon. indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
generator: random number generator. generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns: Returns:
...@@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step( ...@@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step(
# 2. compute predicted original sample from predicted noise also called # 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon: if prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else: elif prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
else:
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
# 3. Clip "predicted x_0" # 3. Clip "predicted x_0"
scale = self.bit_scale scale = self.bit_scale
......
...@@ -194,9 +194,10 @@ def parse_args(): ...@@ -194,9 +194,10 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--predict_epsilon", "--prediction_type",
action="store_true", type=str,
default=True, default="epsilon",
choices=["epsilon", "sample"],
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
) )
...@@ -256,13 +257,13 @@ def main(args): ...@@ -256,13 +257,13 @@ def main(args):
"UpBlock2D", "UpBlock2D",
), ),
) )
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_predict_epsilon: if accepts_prediction_type:
noise_scheduler = DDPMScheduler( noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps, num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule, beta_schedule=args.ddpm_beta_schedule,
predict_epsilon=args.predict_epsilon, prediction_type=args.prediction_type,
) )
else: else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
...@@ -365,9 +366,9 @@ def main(args): ...@@ -365,9 +366,9 @@ def main(args):
# Predict the noise residual # Predict the noise residual
model_output = model(noisy_images, timesteps).sample model_output = model(noisy_images, timesteps).sample
if args.predict_epsilon: if args.prediction_type == "epsilon":
loss = F.mse_loss(model_output, noise) # this could have different weights! loss = F.mse_loss(model_output, noise) # this could have different weights!
else: elif args.prediction_type == "sample":
alpha_t = _extract_into_tensor( alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
) )
...@@ -376,6 +377,8 @@ def main(args): ...@@ -376,6 +377,8 @@ def main(args):
model_output, clean_images, reduction="none" model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper ) # use SNR weighting from distillation paper
loss = loss.mean() loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
accelerator.backward(loss) accelerator.backward(loss)
......
...@@ -195,6 +195,11 @@ class ConfigMixin: ...@@ -195,6 +195,11 @@ class ConfigMixin:
if "dtype" in unused_kwargs: if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype") init_dict["dtype"] = unused_kwargs.pop("dtype")
if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
deprecate("remove this", "0.10.0", "remove")
predict_epsilon = unused_kwargs.pop("predict_epsilon")
init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
# Return model and optionally state and/or unused_kwargs # Return model and optionally state and/or unused_kwargs
model = cls(**init_dict) model = cls(**init_dict)
......
...@@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline): ...@@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
x = x + scale * grad x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim) x = self.reset_x0(x, conditions, self.action_dim)
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: set prediction_type when instantiating the model
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory # apply conditions to the trajectory
......
...@@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline):
generated images. generated images.
""" """
message = ( message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`." " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
) )
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None: if predict_epsilon is not None:
new_config = dict(self.scheduler.config) new_config = dict(self.scheduler.config)
new_config["predict_epsilon"] = predict_epsilon new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self.scheduler._internal_dict = FrozenDict(new_config) self.scheduler._internal_dict = FrozenDict(new_config)
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
...@@ -114,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -114,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t).sample model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> x_t-1 # 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step( image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
).prev_sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -106,6 +106,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -106,6 +106,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion. stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates what the model predicts: the noise (`epsilon`), the samples (`sample`), or `v_prediction`.
One of `epsilon`, `sample`, `v_prediction`.
""" """
...@@ -123,7 +126,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -123,7 +126,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -139,8 +151,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -139,8 +151,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
...@@ -261,17 +271,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -261,17 +271,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called # 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.prediction_type == "sample": elif self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
elif self.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V # predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`" " `v_prediction`"
) )
......
...@@ -23,6 +23,7 @@ import flax ...@@ -23,6 +23,7 @@ import flax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import ( from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin, FlaxSchedulerMixin,
...@@ -108,6 +109,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -108,6 +109,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion. stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates what the model predicts: the noise (`epsilon`), the samples (`sample`), or `v_prediction`.
One of `epsilon`, `sample`, `v_prediction`.
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
...@@ -125,7 +130,17 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -125,7 +130,17 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_schedule: str = "linear", beta_schedule: str = "linear",
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
...@@ -259,7 +274,19 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -259,7 +274,19 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called # 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
# 4. compute variance: "sigma_t(η)" -> see formula (16) # 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
......
...@@ -99,9 +99,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -99,9 +99,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample between -1 and 1 for numerical stability.
predict_epsilon (`bool`): prediction_type (`str`, default `epsilon`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
""" """
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
...@@ -116,8 +116,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -116,8 +116,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
predict_epsilon: bool = True, prediction_type: str = "epsilon",
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -241,13 +250,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -241,13 +250,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
""" """
message = ( message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`." " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
) )
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: if predict_epsilon is not None:
new_config = dict(self.config) new_config = dict(self.config)
new_config["predict_epsilon"] = predict_epsilon new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config) self._internal_dict = FrozenDict(new_config)
t = timestep t = timestep
...@@ -265,10 +274,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -265,10 +274,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called # 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.predict_epsilon: if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else: elif self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DDPMScheduler."
)
# 3. Clip "predicted x_0" # 3. Clip "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
......
...@@ -103,9 +103,9 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -103,9 +103,9 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample between -1 and 1 for numerical stability.
predict_epsilon (`bool`): prediction_type (`str`, default `epsilon`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
...@@ -124,8 +124,17 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -124,8 +124,17 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
trained_betas: Optional[jnp.ndarray] = None, trained_betas: Optional[jnp.ndarray] = None,
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
predict_epsilon: bool = True, prediction_type: str = "epsilon",
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = jnp.asarray(trained_betas) self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -204,7 +213,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -204,7 +213,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep: int, timestep: int,
sample: jnp.ndarray, sample: jnp.ndarray,
key: random.KeyArray, key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True, return_dict: bool = True,
**kwargs, **kwargs,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]: ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
...@@ -227,13 +235,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -227,13 +235,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
message = ( message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`." " FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
) )
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: if predict_epsilon is not None:
new_config = dict(self.config) new_config = dict(self.config)
new_config["predict_epsilon"] = predict_epsilon new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self._internal_dict = FrozenDict(new_config) self._internal_dict = FrozenDict(new_config)
t = timestep t = timestep
...@@ -251,10 +259,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -251,10 +259,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called # 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.predict_epsilon: if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else: elif self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDDPMScheduler."
)
# 3. Clip "predicted x_0" # 3. Clip "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput from .scheduling_utils import SchedulerMixin, SchedulerOutput
...@@ -87,10 +87,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -87,10 +87,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`): solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling. sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`): prediction_type (`str`, default `epsilon`):
we currently support both the noise prediction model and the data prediction model. If the model predicts indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set `v-prediction` is not supported for this scheduler.
`predict_epsilon` to `False`.
thresholding (`bool`, default `False`): thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
...@@ -128,14 +127,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -128,14 +127,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
solver_order: int = 2, solver_order: int = 2,
predict_epsilon: bool = True, prediction_type: str = "epsilon",
thresholding: bool = False, thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995, dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0, sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -221,11 +229,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -221,11 +229,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon: if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
else: elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DPMSolverMultistepScheduler."
)
if self.config.thresholding: if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = torch.quantile( dynamic_max_val = torch.quantile(
...@@ -239,12 +253,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -239,12 +253,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
if self.config.predict_epsilon: if self.config.prediction_type == "epsilon":
return model_output return model_output
else: elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DPMSolverMultistepScheduler."
)
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
......
...@@ -23,6 +23,7 @@ import jax ...@@ -23,6 +23,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import ( from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin, FlaxSchedulerMixin,
...@@ -118,10 +119,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -118,10 +119,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`): solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling. sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`): prediction_type (`str`, default `epsilon`):
we currently support both the noise prediction model and the data prediction model. If the model predicts indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set `v-prediction` is not supported for this scheduler.
`predict_epsilon` to `False`.
thresholding (`bool`, default `False`): thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
...@@ -163,14 +163,23 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -163,14 +163,23 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None, trained_betas: Optional[jnp.ndarray] = None,
solver_order: int = 2, solver_order: int = 2,
predict_epsilon: bool = True, prediction_type: str = "epsilon",
thresholding: bool = False, thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995, dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0, sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
**kwargs,
): ):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None: if trained_betas is not None:
self.betas = jnp.asarray(trained_betas) self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
...@@ -260,11 +269,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -260,11 +269,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon: if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
else: elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDPMSolverMultistepScheduler."
)
if self.config.thresholding: if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = jnp.percentile( dynamic_max_val = jnp.percentile(
...@@ -277,12 +292,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -277,12 +292,17 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
if self.config.predict_epsilon: if self.config.prediction_type == "epsilon":
return model_output return model_output
else: elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDPMSolverMultistepScheduler."
)
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
......
...@@ -92,8 +92,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -92,8 +92,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
...@@ -232,14 +230,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -232,14 +230,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output pred_original_sample = sample - sigma_hat * model_output
elif self.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip # * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
) )
# 2. Convert to an ODE derivative # 2. Convert to an ODE derivative
......
...@@ -68,7 +68,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -68,7 +68,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_predict_epsilon(self): def test_inference_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove") deprecate("remove this test", "0.10.0", "remove")
unet = self.dummy_uncond_unet unet = self.dummy_uncond_unet
scheduler = DDPMScheduler(predict_epsilon=False) scheduler = DDPMScheduler(predict_epsilon=False)
...@@ -98,6 +98,35 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -98,6 +98,35 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
tolerance = 1e-2 if torch_device != "mps" else 3e-2 tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
def test_inference_predict_sample(self):
    """Run a DDPM pipeline configured with `prediction_type="sample"` twice from the same seed.

    The generator is re-seeded with 0 between the two calls. The first call's result is read through
    `.images`, the second through index `[0]`. The test asserts the output shape is `(1, 32, 32, 3)`
    and that the bottom-right 3x3 patch of the last channel matches between both runs within a
    device-dependent tolerance.
    """
    model = self.dummy_uncond_unet
    sample_scheduler = DDPMScheduler(prediction_type="sample")

    pipe = DDPMPipeline(unet=model, scheduler=sample_scheduler)
    pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    on_mps = torch_device == "mps"
    if on_mps:
        # Warmup pass when using mps (see #372)
        _ = pipe(num_inference_steps=1)
        # device type MPS is not supported for torch.Generator() api.
        generator = torch.manual_seed(0)
    else:
        generator = torch.Generator(device=torch_device).manual_seed(0)

    first_run = pipe(generator=generator, num_inference_steps=2, output_type="numpy").images

    # Reset the generator so the second run starts from the same seed.
    generator = generator.manual_seed(0)
    second_run = pipe(generator=generator, num_inference_steps=2, output_type="numpy")[0]

    first_patch = first_run[0, -3:, -3:, -1]
    second_patch = second_run[0, -3:, -3:, -1]

    assert first_run.shape == (1, 32, 32, 3)
    # A looser bound is used on mps than on other devices.
    tolerance = 3e-2 if on_mps else 1e-2
    assert np.abs(first_patch.flatten() - second_patch.flatten()).max() < tolerance
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -26,6 +26,7 @@ from diffusers import ( ...@@ -26,6 +26,7 @@ from diffusers import (
logging, logging,
) )
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import deprecate
from diffusers.utils.testing_utils import CaptureLogger from diffusers.utils.testing_utils import CaptureLogger
...@@ -194,17 +195,27 @@ class ConfigTester(unittest.TestCase): ...@@ -194,17 +195,27 @@ class ConfigTester(unittest.TestCase):
ddpm = DDPMScheduler.from_pretrained( ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", "hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler", subfolder="scheduler",
predict_epsilon=False, prediction_type="sample",
beta_end=8, beta_end=8,
) )
with CaptureLogger(logger) as cap_logger_2: with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88) ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
with CaptureLogger(logger) as cap_logger:
deprecate("remove this case", "0.10.0", "remove")
ddpm_3 = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
beta_end=8,
)
assert ddpm.__class__ == DDPMScheduler assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False assert ddpm.config.prediction_type == "sample"
assert ddpm.config.beta_end == 8 assert ddpm.config.beta_end == 8
assert ddpm_2.config.beta_start == 88 assert ddpm_2.config.beta_start == 88
assert ddpm_3.config.prediction_type == "sample"
# no warning should be thrown # no warning should be thrown
assert cap_logger.out == "" assert cap_logger.out == ""
......
...@@ -20,7 +20,6 @@ import random ...@@ -20,7 +20,6 @@ import random
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
from functools import partial
import numpy as np import numpy as np
import torch import torch
...@@ -332,14 +331,13 @@ class PipelineFastTests(unittest.TestCase): ...@@ -332,14 +331,13 @@ class PipelineFastTests(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
[DDIMScheduler, DDIMPipeline, 32], [DDIMScheduler, DDIMPipeline, 32],
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32], [DDPMScheduler, DDPMPipeline, 32],
[DDIMScheduler, DDIMPipeline, (32, 64)], [DDIMScheduler, DDIMPipeline, (32, 64)],
[partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)], [DDPMScheduler, DDPMPipeline, (64, 32)],
] ]
) )
def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32): def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
unet = self.dummy_uncond_unet(sample_size) unet = self.dummy_uncond_unet(sample_size)
# DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
scheduler = scheduler_fn() scheduler = scheduler_fn()
pipeline = pipeline_fn(unet, scheduler).to(torch_device) pipeline = pipeline_fn(unet, scheduler).to(torch_device)
......
...@@ -599,7 +599,12 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -599,7 +599,12 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
def test_predict_epsilon(self): def test_prediction_type(self):
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
for predict_epsilon in [True, False]: for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon) self.check_over_configs(predict_epsilon=predict_epsilon)
...@@ -795,7 +800,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -795,7 +800,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02, "beta_end": 0.02,
"beta_schedule": "linear", "beta_schedule": "linear",
"solver_order": 2, "solver_order": 2,
"predict_epsilon": True, "prediction_type": "epsilon",
"thresholding": False, "thresholding": False,
"sample_max_value": 1.0, "sample_max_value": 1.0,
"algorithm_type": "dpmsolver++", "algorithm_type": "dpmsolver++",
...@@ -921,10 +926,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -921,10 +926,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
for order in [1, 2, 3]: for order in [1, 2, 3]:
for solver_type in ["midpoint", "heun"]: for solver_type in ["midpoint", "heun"]:
for threshold in [0.5, 1.0, 2.0]: for threshold in [0.5, 1.0, 2.0]:
for predict_epsilon in [True, False]: for prediction_type in ["epsilon", "sample"]:
self.check_over_configs( self.check_over_configs(
thresholding=True, thresholding=True,
predict_epsilon=predict_epsilon, prediction_type=prediction_type,
sample_max_value=threshold, sample_max_value=threshold,
algorithm_type="dpmsolver++", algorithm_type="dpmsolver++",
solver_order=order, solver_order=order,
...@@ -935,17 +940,17 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -935,17 +940,17 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
for algorithm_type in ["dpmsolver", "dpmsolver++"]: for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for solver_type in ["midpoint", "heun"]: for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]: for order in [1, 2, 3]:
for predict_epsilon in [True, False]: for prediction_type in ["epsilon", "sample"]:
self.check_over_configs( self.check_over_configs(
solver_order=order, solver_order=order,
solver_type=solver_type, solver_type=solver_type,
predict_epsilon=predict_epsilon, prediction_type=prediction_type,
algorithm_type=algorithm_type, algorithm_type=algorithm_type,
) )
sample = self.full_loop( sample = self.full_loop(
solver_order=order, solver_order=order,
solver_type=solver_type, solver_type=solver_type,
predict_epsilon=predict_epsilon, prediction_type=prediction_type,
algorithm_type=algorithm_type, algorithm_type=algorithm_type,
) )
assert not torch.isnan(sample).any(), "Samples have nan numbers" assert not torch.isnan(sample).any(), "Samples have nan numbers"
......
...@@ -17,7 +17,7 @@ import unittest ...@@ -17,7 +17,7 @@ import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import is_flax_available from diffusers.utils import deprecate, is_flax_available
from diffusers.utils.testing_utils import require_flax from diffusers.utils.testing_utils import require_flax
...@@ -599,6 +599,26 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -599,6 +599,26 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
assert abs(result_sum - 149.0784) < 1e-2 assert abs(result_sum - 149.0784) < 1e-2
assert abs(result_mean - 0.1941) < 1e-3 assert abs(result_mean - 0.1941) < 1e-3
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_predict_epsilon(self):
    """Run `check_over_configs` with the legacy `predict_epsilon` flag set to `True`, then `False`."""
    # Calls `deprecate` with version "0.10.0" and the reminder "remove this test".
    deprecate("remove this test", "0.10.0", "remove")
    for legacy_flag in (True, False):
        self.check_over_configs(predict_epsilon=legacy_flag)
def test_deprecated_predict_epsilon_to_prediction_type(self):
    """Check how the legacy `predict_epsilon` flag is reflected in `prediction_type`.

    For every class in `self.scheduler_classes`, a scheduler built from a config carrying
    `predict_epsilon=True` must report `prediction_type == "epsilon"`, and one built with
    `predict_epsilon=False` must report `prediction_type == "sample"`.
    """
    # Calls `deprecate` with version "0.10.0" and the reminder "remove this test".
    deprecate("remove this test", "0.10.0", "remove")

    # (legacy flag value, expected prediction_type), checked in this order for each class.
    legacy_to_expected = ((True, "epsilon"), (False, "sample"))

    for scheduler_class in self.scheduler_classes:
        for legacy_flag, expected_type in legacy_to_expected:
            config = self.get_scheduler_config(predict_epsilon=legacy_flag)
            scheduler = scheduler_class.from_config(config)
            assert scheduler.prediction_type == expected_type
@require_flax @require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment