Unverified Commit 249d9bc0 authored by Patrick von Platen, committed by GitHub

[Scheduler] Move predict epsilon to init (#1155)



* [Scheduler] Move predict epsilon to init

* up

* uP

* uP

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 5786b0e2
 import argparse
+import inspect
 import math
 import os
 from pathlib import Path
@@ -190,10 +191,10 @@ def parse_args():
     )
     parser.add_argument(
-        "--predict_mode",
-        type=str,
-        default="eps",
-        help="What the model should predict. 'eps' to predict error, 'x0' to directly predict reconstruction",
+        "--predict_epsilon",
+        action="store_true",
+        default=True,
+        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
     )
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
@@ -252,7 +253,17 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+
+    if accepts_predict_epsilon:
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.ddpm_num_steps,
+            beta_schedule=args.ddpm_beta_schedule,
+            predict_epsilon=args.predict_epsilon,
+        )
+    else:
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -351,9 +362,9 @@ def main(args):
 
                 # Predict the noise residual
                 model_output = model(noisy_images, timesteps).sample
-                if args.predict_mode == "eps":
+                if args.predict_epsilon:
                     loss = F.mse_loss(model_output, noise)  # this could have different weights!
-                elif args.predict_mode == "x0":
+                else:
                     alpha_t = _extract_into_tensor(
                         noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                     )
@@ -401,7 +412,6 @@ def main(args):
                     generator=generator,
                     batch_size=args.eval_batch_size,
                     output_type="numpy",
-                    predict_epsilon=args.predict_mode == "eps",
                 ).images
 
                 # denormalize the images and save to tensorboard
...
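Note on the training-script change above: the `inspect.signature` guard keeps `train_unconditional.py` working against older diffusers releases whose `DDPMScheduler.__init__` does not yet accept `predict_epsilon`. A minimal, standalone sketch of the same feature-detection pattern (the two scheduler classes below are hypothetical stand-ins, not diffusers classes):

```python
import inspect


class OldScheduler:  # stand-in for a release without the new kwarg
    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps


class NewScheduler:  # stand-in for a release with the new kwarg
    def __init__(self, num_train_timesteps=1000, predict_epsilon=True):
        self.num_train_timesteps = num_train_timesteps
        self.predict_epsilon = predict_epsilon


def make_scheduler(cls, **kwargs):
    # Forward only the kwargs the constructor actually accepts.
    accepted = set(inspect.signature(cls.__init__).parameters)
    return cls(**{k: v for k, v in kwargs.items() if k in accepted})


# Both calls succeed; the unsupported kwarg is silently dropped for OldScheduler.
make_scheduler(OldScheduler, num_train_timesteps=10, predict_epsilon=False)
make_scheduler(NewScheduler, num_train_timesteps=10, predict_epsilon=False)
```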
@@ -334,6 +334,11 @@ class ConfigMixin:
 
         # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
         init_dict = {}
         for key in expected_keys:
+            # if config param is passed to kwarg and is present in config dict
+            # it should overwrite existing config dict key
+            if key in kwargs and key in config_dict:
+                config_dict[key] = kwargs.pop(key)
+
             if key in kwargs:
                 # overwrite key
                 init_dict[key] = kwargs.pop(key)
...
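The `ConfigMixin` change above is what makes `from_config(..., predict_epsilon=False)` stick: a keyword argument that also exists in the loaded config dict now overwrites the stored value before the `__init__` kwargs are assembled. A short usage sketch (requires a hub download; the model id is the one used in the tests further down):

```python
from diffusers import DDPMScheduler

# The explicit kwarg wins over whatever the saved scheduler config contains.
scheduler = DDPMScheduler.from_config("google/ddpm-celebahq-256", predict_epsilon=False)
assert scheduler.config.predict_epsilon is False
```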
@@ -18,7 +18,9 @@ from typing import Optional, Tuple, Union
 
 import torch
 
+from ...configuration_utils import FrozenDict
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate
 
 
 class DDPMPipeline(DiffusionPipeline):
@@ -45,7 +47,6 @@ class DDPMPipeline(DiffusionPipeline):
         num_inference_steps: int = 1000,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        predict_epsilon: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
@@ -69,6 +70,16 @@ class DDPMPipeline(DiffusionPipeline):
             `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
             generated images.
         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            new_config = dict(self.scheduler.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self.scheduler._internal_dict = FrozenDict(new_config)
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
...
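For callers, the pipeline change above moves `predict_epsilon` from the `__call__` site to scheduler construction; passing it to the pipeline still works until 0.10.0 via `deprecate(..., take_from=kwargs)`, which pops the old kwarg, warns, and mirrors the value into the scheduler's frozen config. A sketch of the migration (the model id is a placeholder for any DDPM checkpoint):

```python
from diffusers import DDPMPipeline, DDPMScheduler

# Deprecated style, removed in 0.10.0:
#     pipe(..., predict_epsilon=False)

# New style: set the flag once when building the scheduler.
scheduler = DDPMScheduler.from_config("google/ddpm-celebahq-256", predict_epsilon=False)
pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256", scheduler=scheduler)
images = pipe(num_inference_steps=2, output_type="numpy").images
```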
@@ -21,8 +21,8 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin
 
@@ -99,6 +99,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
     """
 
@@ -121,6 +123,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -221,9 +224,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.FloatTensor,
         timestep: int,
         sample: torch.FloatTensor,
-        predict_epsilon=True,
         generator=None,
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -234,8 +237,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
@@ -245,6 +246,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
         t = timestep
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
@@ -260,7 +271,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         else:
             pred_original_sample = model_output
...
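What the `self.config.predict_epsilon` branch above computes: with `predict_epsilon=True` the model output is interpreted as the noise epsilon and the clean sample is recovered via formula (15) of https://arxiv.org/pdf/2006.11239.pdf, x0 = (x_t - sqrt(1 - alpha_prod_t) * eps) / sqrt(alpha_prod_t); with `predict_epsilon=False` the output is taken to be x0 directly. A self-contained numerical check of the conversion (shapes and values are illustrative):

```python
import torch

alpha_prod_t = torch.tensor(0.9)  # cumulative alpha-bar at some timestep t
beta_prod_t = 1 - alpha_prod_t

x0 = torch.randn(1, 3, 8, 8)      # a "clean" sample
eps = torch.randn_like(x0)        # the noise the model would predict
sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps  # noisy x_t

# predict_epsilon=True path: invert the forward process to get x0 back
pred_original_sample = (sample - beta_prod_t**0.5 * eps) / alpha_prod_t**0.5
assert torch.allclose(pred_original_sample, x0, atol=1e-5)
```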
@@ -22,7 +22,8 @@ import flax
 import jax.numpy as jnp
 from jax import random
 
-from ..configuration_utils import ConfigMixin, register_to_config
+from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
+from ..utils import deprecate
 from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
 
@@ -97,7 +98,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+        predict_epsilon (`bool`):
+            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
     """
 
@@ -115,6 +117,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         trained_betas: Optional[jnp.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -196,6 +199,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         key: random.KeyArray,
-        predict_epsilon: bool = True,
         return_dict: bool = True,
+        **kwargs,
     ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -208,8 +212,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
             key (`random.KeyArray`): a PRNG key.
-            predict_epsilon (`bool`):
-                optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
 
         Returns:
@@ -217,6 +219,16 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             `tuple`. When returning a tuple, the first element is the sample tensor.
         """
+        message = (
+            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+            new_config = dict(self.config)
+            new_config["predict_epsilon"] = predict_epsilon
+            self._internal_dict = FrozenDict(new_config)
+
         t = timestep
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
@@ -232,7 +244,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
 
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if predict_epsilon:
+        if self.config.predict_epsilon:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         else:
             pred_original_sample = model_output
...
@@ -42,7 +42,6 @@ class CustomLocalPipeline(DiffusionPipeline):
         self,
         batch_size: int = 1,
         generator: Optional[torch.Generator] = None,
-        eta: float = 0.0,
         num_inference_steps: int = 50,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -89,7 +88,7 @@ class CustomLocalPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, t, image).prev_sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
...
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -28,8 +29,74 @@ torch.backends.cuda.matmul.allow_tf32 = False
 
 class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
-    # FIXME: add fast tests
-    pass
+    @property
+    def dummy_uncond_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        return model
+
+    def test_inference(self):
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler()
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array(
+            [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
+        )
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+
+    def test_inference_predict_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+        unet = self.dummy_uncond_unet
+        scheduler = DDPMScheduler(predict_epsilon=False)
+
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_eps_slice = image_eps[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
 
 @slow
...
@@ -21,6 +21,7 @@ import unittest
 import diffusers
 from diffusers import (
     DDIMScheduler,
+    DDPMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
@@ -291,6 +292,29 @@ class ConfigTester(unittest.TestCase):
         # no warning should be thrown
         assert cap_logger.out == ""
 
+    def test_overwrite_config_on_load(self):
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with CaptureLogger(logger) as cap_logger:
+            ddpm = DDPMScheduler.from_config(
+                "hf-internal-testing/tiny-stable-diffusion-torch",
+                subfolder="scheduler",
+                predict_epsilon=False,
+                beta_end=8,
+            )
+
+        with CaptureLogger(logger) as cap_logger_2:
+            ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
+
+        assert ddpm.__class__ == DDPMScheduler
+        assert ddpm.config.predict_epsilon is False
+        assert ddpm.config.beta_end == 8
+        assert ddpm_2.config.beta_start == 88
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+        assert cap_logger_2.out == ""
+
     def test_load_dpmsolver(self):
         logger = logging.get_logger("diffusers.configuration_utils")
...
@@ -107,6 +107,7 @@ class CustomPipelineTests(unittest.TestCase):
         images, output_str = pipeline(num_inference_steps=2, output_type="np")
 
         assert images[0].shape == (1, 32, 32, 3)
+
         # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
         assert output_str == "This is a test"
...
@@ -33,7 +33,7 @@ from diffusers import (
     ScoreSdeVeScheduler,
     VQDiffusionScheduler,
 )
-from diffusers.utils import torch_device
+from diffusers.utils import deprecate, torch_device
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -393,6 +393,34 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
 
+    def test_predict_epsilon(self):
+        for predict_epsilon in [True, False]:
+            self.check_over_configs(predict_epsilon=predict_epsilon)
+
+    def test_deprecated_epsilon(self):
+        deprecate("remove this test", "0.10.0", "remove")
+
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+
+        sample = self.dummy_sample_deter
+        residual = 0.1 * self.dummy_sample_deter
+        time_step = 4
+
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
+
+        kwargs = {}
+        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+            kwargs["generator"] = torch.Generator().manual_seed(0)
+        output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
+
+        kwargs = {}
+        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+            kwargs["generator"] = torch.Generator().manual_seed(0)
+        output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
+
+        assert (output - output_eps).abs().sum() < 1e-5
+
     def test_time_indices(self):
         for t in [0, 500, 999]:
             self.check_over_forward(time_step=t)
...
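The tests above pin down the transition: calling `step(..., predict_epsilon=False)` on a default scheduler must produce the same result as a scheduler constructed with `predict_epsilon=False`. After 0.10.0 only the init-time form remains; a hedged sketch of the post-deprecation call pattern (the stand-in model output is illustrative):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(predict_epsilon=False)  # set once at construction
scheduler.set_timesteps(10)

sample = torch.randn(1, 3, 32, 32)
model_output = 0.1 * sample  # stand-in for a UNet prediction

# No predict_epsilon kwarg anymore; step() reads it from scheduler.config.
prev = scheduler.step(model_output, int(scheduler.timesteps[0]), sample).prev_sample
```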