Unverified Commit 8263cf00 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

refactor DPMSolverMultistepScheduler using sigmas (#4986)




---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 74e43a4f
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -186,6 +187,14 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -186,6 +187,14 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps) self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
self._step_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
...@@ -225,17 +234,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -225,17 +234,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas: if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas) log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
self.sigmas = torch.from_numpy(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
# when num_inference_steps == num_train_timesteps, we can end up with sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -245,6 +253,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -245,6 +253,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
""" """
...@@ -280,8 +291,57 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -280,8 +291,57 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
Convert the model output to the corresponding type the DEIS algorithm needs. Convert the model output to the corresponding type the DEIS algorithm needs.
...@@ -298,13 +358,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -298,13 +358,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
...@@ -316,7 +389,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -316,7 +389,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = self._threshold_sample(x0_pred) x0_pred = self._threshold_sample(x0_pred)
if self.config.algorithm_type == "deis": if self.config.algorithm_type == "deis":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
return (sample - alpha_t * x0_pred) / sigma_t return (sample - alpha_t * x0_pred) / sigma_t
else: else:
raise NotImplementedError("only support log-rho multistep deis now") raise NotImplementedError("only support log-rho multistep deis now")
...@@ -324,9 +396,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -324,9 +396,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def deis_first_order_update( def deis_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the first-order DEIS (equivalent to DDIM). One step for the first-order DEIS (equivalent to DDIM).
...@@ -345,9 +417,33 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -345,9 +417,33 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep] if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s h = lambda_t - lambda_s
if self.config.algorithm_type == "deis": if self.config.algorithm_type == "deis":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
...@@ -358,9 +454,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -358,9 +454,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_deis_second_order_update( def multistep_deis_second_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the second-order multistep DEIS. One step for the second-order multistep DEIS.
...@@ -368,10 +464,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -368,10 +464,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -379,10 +471,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -379,10 +471,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2]
alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1]
sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1]
rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
...@@ -403,9 +523,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -403,9 +523,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_deis_third_order_update( def multistep_deis_third_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the third-order multistep DEIS. One step for the third-order multistep DEIS.
...@@ -413,10 +533,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -413,10 +533,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by diffusion process. A current instance of a sample created by diffusion process.
...@@ -424,15 +540,47 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -424,15 +540,47 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2]
sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2]
rho_t, rho_s0, rho_s1, rho_s2 = ( rho_t, rho_s0, rho_s1, rho_s2 = (
sigma_t / alpha_t, sigma_t / alpha_t,
sigma_s0 / alpha_s0, sigma_s0 / alpha_s0,
sigma_s1 / alpha_s1, sigma_s1 / alpha_s1,
simga_s2 / alpha_s2, sigma_s2 / alpha_s2,
) )
if self.config.algorithm_type == "deis": if self.config.algorithm_type == "deis":
...@@ -460,6 +608,25 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -460,6 +608,25 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError("only support log-rho multistep deis now") raise NotImplementedError("only support log-rho multistep deis now")
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -492,42 +659,34 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -492,42 +659,34 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
) )
if isinstance(timestep, torch.Tensor): if self.step_index is None:
timestep = timestep.to(self.timesteps.device) self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = ( lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
) )
lower_order_second = ( lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
) )
model_output = self.convert_model_output(model_output, timestep, sample) model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1): for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output self.model_outputs[-1] = model_output
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) prev_sample = self.deis_first_order_update(model_output, sample=sample)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_deis_second_order_update(self.model_outputs, sample=sample)
prev_sample = self.multistep_deis_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
else: else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_deis_third_order_update(self.model_outputs, sample=sample)
prev_sample = self.multistep_deis_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order: if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1 self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
...@@ -548,28 +707,30 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -548,28 +707,30 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.FloatTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sigma = sigmas[step_indices].flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sigma.shape) < len(original_samples.shape):
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils.torch_utils import randn_tensor from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -203,6 +204,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -203,6 +204,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps) self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
self._step_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
""" """
...@@ -242,19 +251,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -242,19 +251,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
) )
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas: if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
self.sigmas = torch.from_numpy(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
# when num_inference_steps == num_train_timesteps, we can end up with sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -264,6 +273,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -264,6 +273,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
""" """
...@@ -323,6 +335,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -323,6 +335,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape) t = t.reshape(sigma.shape)
return t return t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
...@@ -338,7 +356,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -338,7 +356,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas return sigmas
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
...@@ -355,8 +377,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -355,8 +377,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -364,6 +384,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -364,6 +384,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
...@@ -371,12 +403,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -371,12 +403,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]: if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3] model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
...@@ -398,10 +432,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -398,10 +432,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else: else:
epsilon = model_output epsilon = model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample epsilon = alpha_t * model_output + sigma_t * sample
else: else:
raise ValueError( raise ValueError(
...@@ -410,7 +446,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -410,7 +446,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
) )
if self.config.thresholding: if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred) x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t epsilon = (sample - alpha_t * x0_pred) / sigma_t
...@@ -420,10 +457,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -420,10 +457,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None, noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the first-order DPMSolver (equivalent to DDIM). One step for the first-order DPMSolver (equivalent to DDIM).
...@@ -431,10 +468,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -431,10 +468,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -442,9 +475,33 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -442,9 +475,33 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
...@@ -469,10 +526,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -469,10 +526,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_second_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None, noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the second-order multistep DPMSolver. One step for the second-order multistep DPMSolver.
...@@ -480,10 +537,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -480,10 +537,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -491,11 +544,43 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -491,11 +544,43 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1) D0, D1 = m0, (1.0 / r0) * (m0 - m1)
...@@ -564,9 +649,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -564,9 +649,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_dpm_solver_third_order_update( def multistep_dpm_solver_third_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the third-order multistep DPMSolver. One step for the third-order multistep DPMSolver.
...@@ -574,10 +659,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -574,10 +659,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by diffusion process. A current instance of a sample created by diffusion process.
...@@ -585,16 +666,47 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -585,16 +666,47 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
self.lambda_t[t], if sample is None:
self.lambda_t[s0], if len(args) > 2:
self.lambda_t[s1], sample = args[2]
self.lambda_t[s2], else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
) )
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h r0, r1 = h_0 / h, h_1 / h
D0 = m0 D0 = m0
...@@ -619,6 +731,25 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -619,6 +731,25 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
) )
return x_t return x_t
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -654,22 +785,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -654,22 +785,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
) )
if isinstance(timestep, torch.Tensor): if self.step_index is None:
timestep = timestep.to(self.timesteps.device) self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
lower_order_final = ( lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
) )
lower_order_second = ( lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
) )
model_output = self.convert_model_output(model_output, timestep, sample) model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1): for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output self.model_outputs[-1] = model_output
...@@ -682,23 +808,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -682,23 +808,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noise = None noise = None
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update( prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
else: else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order: if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1 self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
...@@ -719,28 +840,30 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -719,28 +840,30 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.FloatTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sigma = sigmas[step_indices].flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sigma.shape) < len(original_samples.shape):
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils.torch_utils import randn_tensor from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -203,8 +204,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -203,8 +204,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps) self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
self._step_index = None
self.use_karras_sigmas = use_karras_sigmas self.use_karras_sigmas = use_karras_sigmas
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -244,11 +253,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -244,11 +253,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
) )
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.use_karras_sigmas: if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = timesteps.copy().astype(np.int64) timesteps = timesteps.copy().astype(np.int64)
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
(1 - self.alphas_cumprod[self.noisiest_timestep]) / self.alphas_cumprod[self.noisiest_timestep]
) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_max]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas) self.sigmas = torch.from_numpy(sigmas)
...@@ -266,6 +283,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -266,6 +283,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
""" """
...@@ -325,6 +345,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -325,6 +345,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape) t = t.reshape(sigma.shape)
return t return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
...@@ -341,7 +368,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -341,7 +368,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
...@@ -358,8 +389,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -358,8 +389,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -367,6 +396,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -367,6 +396,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
...@@ -374,12 +415,14 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -374,12 +415,14 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned", "learned_range"]: if self.config.variance_type in ["learned", "learned_range"]:
model_output = model_output[:, :3] model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
...@@ -401,10 +444,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -401,10 +444,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
else: else:
epsilon = model_output epsilon = model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample epsilon = alpha_t * model_output + sigma_t * sample
else: else:
raise ValueError( raise ValueError(
...@@ -413,20 +458,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -413,20 +458,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
) )
if self.config.thresholding: if self.config.thresholding:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * epsilon) / alpha_t x0_pred = (sample - sigma_t * epsilon) / alpha_t
x0_pred = self._threshold_sample(x0_pred) x0_pred = self._threshold_sample(x0_pred)
epsilon = (sample - alpha_t * x0_pred) / sigma_t epsilon = (sample - alpha_t * x0_pred) / sigma_t
return epsilon return epsilon
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None, noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the first-order DPMSolver (equivalent to DDIM). One step for the first-order DPMSolver (equivalent to DDIM).
...@@ -434,10 +481,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -434,10 +481,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -445,27 +488,62 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -445,27 +488,62 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
elif "sde" in self.config.algorithm_type: elif self.config.algorithm_type == "sde-dpmsolver++":
raise NotImplementedError( assert noise is not None
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." x_t = (
(sigma_t / sigma_s * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = (
(alpha_t / alpha_s) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
) )
return x_t return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
def multistep_dpm_solver_second_order_update( def multistep_dpm_solver_second_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None, noise: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the second-order multistep DPMSolver. One step for the second-order multistep DPMSolver.
...@@ -473,10 +551,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -473,10 +551,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -484,11 +558,43 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -484,11 +558,43 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1) D0, D1 = m0, (1.0 / r0) * (m0 - m1)
...@@ -520,19 +626,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -520,19 +626,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
- (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
) )
elif "sde" in self.config.algorithm_type: elif self.config.algorithm_type == "sde-dpmsolver++":
raise NotImplementedError( assert noise is not None
f"Inversion step is not yet implemented for algorithm type {self.config.algorithm_type}." if self.config.solver_type == "midpoint":
) x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- (sigma_t * (torch.exp(h) - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
elif self.config.solver_type == "heun":
x_t = (
(alpha_t / alpha_s0) * sample
- 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
- 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
+ sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
)
return x_t return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
def multistep_dpm_solver_third_order_update( def multistep_dpm_solver_third_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the third-order multistep DPMSolver. One step for the third-order multistep DPMSolver.
...@@ -540,10 +674,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -540,10 +674,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output_list (`List[torch.FloatTensor]`): model_output_list (`List[torch.FloatTensor]`):
The direct outputs from learned diffusion model at current and latter timesteps. The direct outputs from learned diffusion model at current and latter timesteps.
timestep (`int`):
The current and latter discrete timestep in the diffusion chain.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by diffusion process. A current instance of a sample created by diffusion process.
...@@ -551,16 +681,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -551,16 +681,47 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
self.lambda_t[t], if sample is None:
self.lambda_t[s0], if len(args) > 2:
self.lambda_t[s1], sample = args[2]
self.lambda_t[s2], else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
) )
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h r0, r1 = h_0 / h, h_1 / h
D0 = m0 D0 = m0
...@@ -585,6 +746,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -585,6 +746,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
) )
return x_t return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -604,6 +786,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -604,6 +786,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`): return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
...@@ -618,24 +802,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -618,24 +802,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
) )
if isinstance(timestep, torch.Tensor): if self.step_index is None:
timestep = timestep.to(self.timesteps.device) self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = (
self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
)
lower_order_final = ( lower_order_final = (
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
) )
lower_order_second = ( lower_order_second = (
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
) )
model_output = self.convert_model_output(model_output, timestep, sample) model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1): for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output self.model_outputs[-1] = model_output
...@@ -648,23 +825,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -648,23 +825,18 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
noise = None noise = None
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update( prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
model_output, timestep, prev_timestep, sample, noise=noise
)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
)
else: else:
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, timestep_list, prev_timestep, sample
)
if self.lower_order_nums < self.config.solver_order: if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1 self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
...@@ -686,28 +858,30 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -686,28 +858,30 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.FloatTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sigma = sigmas[step_indices].flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sigma.shape) < len(original_samples.shape):
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging from ..utils import deprecate, logging
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
...@@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.sample = None self.sample = None
self.order_list = self.get_order_list(num_train_timesteps) self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None
def get_order_list(self, num_inference_steps: int) -> List[int]: def get_order_list(self, num_inference_steps: int) -> List[int]:
""" """
...@@ -232,6 +233,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -232,6 +233,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
orders = [1] * steps orders = [1] * steps
return orders return orders
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -256,11 +264,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -256,11 +264,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas: if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas) log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas) self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
self.model_outputs = [None] * self.config.solver_order self.model_outputs = [None] * self.config.solver_order
...@@ -274,6 +287,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -274,6 +287,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.order_list = self.get_order_list(num_inference_steps) self.order_list = self.get_order_list(num_inference_steps)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
""" """
...@@ -333,6 +349,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -333,6 +349,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape) t = t.reshape(sigma.shape)
return t return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
...@@ -348,7 +371,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -348,7 +371,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
return sigmas return sigmas
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
...@@ -365,8 +392,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -365,8 +392,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.FloatTensor`): model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model. The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`): sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
...@@ -374,18 +399,32 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -374,18 +399,32 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model. # DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
# DPM-Solver and DPM-Solver++ only need the "mean" output. # DPM-Solver and DPM-Solver++ only need the "mean" output.
if self.config.variance_type in ["learned_range"]: if self.config.variance_type in ["learned_range"]:
model_output = model_output[:, :3] model_output = model_output[:, :3]
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
...@@ -405,11 +444,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -405,11 +444,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
model_output = model_output[:, :3] model_output = model_output[:, :3]
return model_output return model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon return epsilon
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
epsilon = alpha_t * model_output + sigma_t * sample epsilon = alpha_t * model_output + sigma_t * sample
return epsilon return epsilon
else: else:
...@@ -421,9 +462,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -421,9 +462,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the first-order DPMSolver (equivalent to DDIM). One step for the first-order DPMSolver (equivalent to DDIM).
...@@ -442,9 +483,31 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -442,9 +483,31 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++": if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
...@@ -455,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -455,9 +518,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_second_order_update( def singlestep_dpm_solver_second_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
...@@ -477,11 +540,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -477,11 +540,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1]
sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1]
h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1
r0 = h_0 / h r0 = h_0 / h
D0, D1 = m1, (1.0 / r0) * (m0 - m1) D0, D1 = m1, (1.0 / r0) * (m0 - m1)
...@@ -518,9 +612,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -518,9 +612,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_third_order_update( def singlestep_dpm_solver_third_order_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
...@@ -540,16 +634,47 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -540,16 +634,47 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
self.lambda_t[t], if sample is None:
self.lambda_t[s0], if len(args) > 2:
self.lambda_t[s1], sample = args[2]
self.lambda_t[s2], else:
raise ValueError(" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
self.sigmas[self.step_index - 2],
) )
alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2]
sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h r0, r1 = h_0 / h, h_1 / h
D0 = m2 D0 = m2
...@@ -591,10 +716,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -591,10 +716,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def singlestep_dpm_solver_update( def singlestep_dpm_solver_update(
self, self,
model_output_list: List[torch.FloatTensor], model_output_list: List[torch.FloatTensor],
timestep_list: List[int], *args,
prev_timestep: int, sample: torch.FloatTensor = None,
sample: torch.FloatTensor, order: int = None,
order: int, **kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the singlestep DPMSolver. One step for the singlestep DPMSolver.
...@@ -615,19 +740,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -615,19 +740,60 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(" missing`sample` as a required keyward argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(" missing `order` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if order == 1: if order == 1:
return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
elif order == 2: elif order == 2:
return self.singlestep_dpm_solver_second_order_update( return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
model_output_list, timestep_list, prev_timestep, sample
)
elif order == 3: elif order == 3:
return self.singlestep_dpm_solver_third_order_update( return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
model_output_list, timestep_list, prev_timestep, sample
)
else: else:
raise ValueError(f"Order must be 1, 2, 3, got {order}") raise ValueError(f"Order must be 1, 2, 3, got {order}")
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -660,21 +826,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -660,21 +826,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
) )
if isinstance(timestep, torch.Tensor): if self.step_index is None:
timestep = timestep.to(self.timesteps.device) self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
model_output = self.convert_model_output(model_output, timestep, sample) model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1): for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output self.model_outputs[-1] = model_output
order = self.order_list[step_index] order = self.order_list[self.step_index]
# For img2img denoising might start with order>1 which is not possible # For img2img denoising might start with order>1 which is not possible
# In this case make sure that the first two steps are both order=1 # In this case make sure that the first two steps are both order=1
...@@ -685,10 +845,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -685,10 +845,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if order == 1: if order == 1:
self.sample = sample self.sample = sample
timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
prev_sample = self.singlestep_dpm_solver_update(
self.model_outputs, timestep_list, prev_timestep, self.sample, order # upon completion increase step index by one
) self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
...@@ -710,28 +870,30 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -710,28 +870,30 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.FloatTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sigma = sigmas[step_indices].flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sigma.shape) < len(original_samples.shape):
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -22,10 +22,16 @@ import numpy as np ...@@ -22,10 +22,16 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1]. (1-beta) over time from t = [0,1].
...@@ -38,19 +44,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -38,19 +44,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
num_diffusion_timesteps (`int`): the number of betas to produce. num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities. prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns: Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
""" """
if alpha_transform_type == "cosine":
def alpha_bar(time_step): def alpha_bar_fn(t):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = [] betas = []
for i in range(num_diffusion_timesteps): for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32) return torch.tensor(betas, dtype=torch.float32)
...@@ -181,6 +198,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -181,6 +198,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.disable_corrector = disable_corrector self.disable_corrector = disable_corrector
self.solver_p = solver_p self.solver_p = solver_p
self.last_sample = None self.last_sample = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
...@@ -220,17 +245,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -220,17 +245,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas: if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas) log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
self.sigmas = torch.from_numpy(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
# when num_inference_steps == num_train_timesteps, we can end up with sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
# duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -243,6 +267,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -243,6 +267,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.solver_p: if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device) self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
""" """
...@@ -302,6 +329,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -302,6 +329,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape) t = t.reshape(sigma.shape)
return t return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
...@@ -317,7 +351,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -317,7 +351,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas return sigmas
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self,
model_output: torch.FloatTensor,
*args,
sample: torch.FloatTensor = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
r""" r"""
Convert the model output to the corresponding type the UniPC algorithm needs. Convert the model output to the corresponding type the UniPC algorithm needs.
...@@ -334,14 +372,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -334,14 +372,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The converted model output. The converted model output.
""" """
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0: if self.predict_x0:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
...@@ -357,11 +409,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -357,11 +409,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
return model_output return model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon return epsilon
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample epsilon = alpha_t * model_output + sigma_t * sample
return epsilon return epsilon
else: else:
...@@ -373,9 +423,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -373,9 +423,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_p_bh_update( def multistep_uni_p_bh_update(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
prev_timestep: int, *args,
sample: torch.FloatTensor, sample: torch.FloatTensor = None,
order: int, order: int = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
...@@ -394,10 +445,26 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -394,10 +445,26 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The sample tensor at the previous timestep. The sample tensor at the previous timestep.
""" """
timestep_list = self.timestep_list prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(" missing `sample` as a required keyward argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
raise ValueError(" missing `order` as a required keyward argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs model_output_list = self.model_outputs
s0, t = self.timestep_list[-1], prev_timestep s0 = self.timestep_list[-1]
m0 = model_output_list[-1] m0 = model_output_list[-1]
x = sample x = sample
...@@ -405,9 +472,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -405,9 +472,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = self.solver_p.step(model_output, s0, x).prev_sample x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t return x_t
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0 h = lambda_t - lambda_s0
device = sample.device device = sample.device
...@@ -415,9 +485,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -415,9 +485,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rks = [] rks = []
D1s = [] D1s = []
for i in range(1, order): for i in range(1, order):
si = timestep_list[-(i + 1)] si = self.step_index - i
mi = model_output_list[-(i + 1)] mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si] alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h rk = (lambda_si - lambda_s0) / h
rks.append(rk) rks.append(rk)
D1s.append((mi - m0) / rk) D1s.append((mi - m0) / rk)
...@@ -481,10 +552,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -481,10 +552,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def multistep_uni_c_bh_update( def multistep_uni_c_bh_update(
self, self,
this_model_output: torch.FloatTensor, this_model_output: torch.FloatTensor,
this_timestep: int, *args,
last_sample: torch.FloatTensor, last_sample: torch.FloatTensor = None,
this_sample: torch.FloatTensor, this_sample: torch.FloatTensor = None,
order: int, order: int = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
""" """
One step for the UniC (B(h) version). One step for the UniC (B(h) version).
...@@ -505,18 +577,42 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -505,18 +577,42 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
The corrected sample tensor at the current timestep. The corrected sample tensor at the current timestep.
""" """
timestep_list = self.timestep_list this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
if last_sample is None:
if len(args) > 1:
last_sample = args[1]
else:
raise ValueError(" missing`last_sample` as a required keyward argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
raise ValueError(" missing`this_sample` as a required keyward argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(" missing`order` as a required keyward argument")
if this_timestep is not None:
deprecate(
"this_timestep",
"1.0.0",
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs model_output_list = self.model_outputs
s0, t = timestep_list[-1], this_timestep
m0 = model_output_list[-1] m0 = model_output_list[-1]
x = last_sample x = last_sample
x_t = this_sample x_t = this_sample
model_t = this_model_output model_t = this_model_output
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0 h = lambda_t - lambda_s0
device = this_sample.device device = this_sample.device
...@@ -524,9 +620,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -524,9 +620,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rks = [] rks = []
D1s = [] D1s = []
for i in range(1, order): for i in range(1, order):
si = timestep_list[-(i + 1)] si = self.step_index - (i + 1)
mi = model_output_list[-(i + 1)] mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si] alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h rk = (lambda_si - lambda_s0) / h
rks.append(rk) rks.append(rk)
D1s.append((mi - m0) / rk) D1s.append((mi - m0) / rk)
...@@ -589,6 +686,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -589,6 +686,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = x_t.to(x.dtype) x_t = x_t.to(x.dtype)
return x_t return x_t
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index = step_index
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
...@@ -616,37 +732,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -616,37 +732,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
tuple is returned where the first element is the sample tensor. tuple is returned where the first element is the sample tensor.
""" """
if self.num_inference_steps is None: if self.num_inference_steps is None:
raise ValueError( raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
) )
if isinstance(timestep, torch.Tensor): if self.step_index is None:
timestep = timestep.to(self.timesteps.device) self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero()
if len(step_index) == 0:
step_index = len(self.timesteps) - 1
else:
step_index = step_index.item()
use_corrector = ( use_corrector = (
step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
) )
model_output_convert = self.convert_model_output(model_output, timestep, sample) model_output_convert = self.convert_model_output(model_output, sample=sample)
if use_corrector: if use_corrector:
sample = self.multistep_uni_c_bh_update( sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert, this_model_output=model_output_convert,
this_timestep=timestep,
last_sample=self.last_sample, last_sample=self.last_sample,
this_sample=sample, this_sample=sample,
order=self.this_order, order=self.this_order,
) )
# now prepare to run the predictor
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
for i in range(self.config.solver_order - 1): for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1] self.timestep_list[i] = self.timestep_list[i + 1]
...@@ -655,7 +761,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -655,7 +761,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timestep_list[-1] = timestep self.timestep_list[-1] = timestep
if self.config.lower_order_final: if self.config.lower_order_final:
this_order = min(self.config.solver_order, len(self.timesteps) - step_index) this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
else: else:
this_order = self.config.solver_order this_order = self.config.solver_order
...@@ -665,7 +771,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -665,7 +771,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.last_sample = sample self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update( prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used model_output=model_output, # pass the original non-converted model output, in case solver-p is used
prev_timestep=prev_timestep,
sample=sample, sample=sample,
order=self.this_order, order=self.this_order,
) )
...@@ -673,6 +778,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -673,6 +778,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.lower_order_nums < self.config.solver_order: if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1 self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
...@@ -693,28 +801,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -693,28 +801,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.FloatTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sigma = sigmas[step_indices].flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sigma.shape) < len(original_samples.shape):
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -51,6 +51,7 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest): ...@@ -51,6 +51,7 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
output, new_output = sample, sample output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1): for t in range(time_step, time_step + scheduler.config.solver_order + 1):
t = scheduler.timesteps[t]
output = scheduler.step(residual, t, output, **kwargs).prev_sample output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
......
...@@ -59,6 +59,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -59,6 +59,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
output, new_output = sample, sample output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1): for t in range(time_step, time_step + scheduler.config.solver_order + 1):
t = new_scheduler.timesteps[t]
output = scheduler.step(residual, t, output, **kwargs).prev_sample output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
...@@ -91,6 +92,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -91,6 +92,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
# copy over dummy past residual (must be after setting timesteps) # copy over dummy past residual (must be after setting timesteps)
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
time_step = new_scheduler.timesteps[time_step]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
...@@ -264,10 +266,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -264,10 +266,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert sample.dtype == torch.float16 assert sample.dtype == torch.float16
def test_unique_timesteps(self, **config): def test_duplicated_timesteps(self, **config):
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config) scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(scheduler.config.num_train_timesteps) scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps assert len(scheduler.timesteps) == scheduler.num_inference_steps
...@@ -54,6 +54,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -54,6 +54,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
output, new_output = sample, sample output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1): for t in range(time_step, time_step + scheduler.config.solver_order + 1):
t = scheduler.timesteps[t]
output = scheduler.step(residual, t, output, **kwargs).prev_sample output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
...@@ -222,7 +223,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -222,7 +223,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 1.7833) < 1e-3 assert abs(result_mean.item() - 1.7833) < 2e-3
def test_switch(self): def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results # make sure that iterating over schedulers with same config names gives same results
......
...@@ -58,6 +58,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): ...@@ -58,6 +58,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
output, new_output = sample, sample output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1): for t in range(time_step, time_step + scheduler.config.solver_order + 1):
t = scheduler.timesteps[t]
output = scheduler.step(residual, t, output, **kwargs).prev_sample output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
...@@ -248,3 +249,33 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): ...@@ -248,3 +249,33 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
sample = scheduler.step(residual, t, sample).prev_sample sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16 assert sample.dtype == torch.float16
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]
time_step_0 = scheduler.timesteps[0]
time_step_1 = scheduler.timesteps[1]
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
...@@ -52,6 +52,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): ...@@ -52,6 +52,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
output, new_output = sample, sample output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1): for t in range(time_step, time_step + scheduler.config.solver_order + 1):
t = scheduler.timesteps[t]
output = scheduler.step(residual, t, output, **kwargs).prev_sample output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample
...@@ -241,11 +242,3 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): ...@@ -241,11 +242,3 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
sample = scheduler.step(residual, t, sample).prev_sample sample = scheduler.step(residual, t, sample).prev_sample
assert sample.dtype == torch.float16 assert sample.dtype == torch.float16
def test_unique_timesteps(self, **config):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment