Unverified Commit 524535b5 authored by Nipun Jindal's avatar Nipun Jindal Committed by GitHub
Browse files

[2064]: Add Karras to DPMSolverMultistepScheduler (#3001)



* [2737]: Add Karras DPMSolverMultistepScheduler

* [2737]: Add Karras DPMSolverMultistepScheduler

* Add test

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix: repo consistency.

* remove Copied from statement from the set_timestep method.

* fix: test

* Empty commit.
Co-authored-by: default avatarnjindal <njindal@adobe.com>

---------
Co-authored-by: default avatarnjindal <njindal@adobe.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 7b2407f4
...@@ -171,7 +171,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -171,7 +171,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
......
...@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`): lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -136,6 +139,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -136,6 +139,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -181,6 +185,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -181,6 +185,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps) self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
self.use_karras_sigmas = use_karras_sigmas
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
...@@ -199,6 +204,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -199,6 +204,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
.astype(np.int64) .astype(np.int64)
) )
if self.use_karras_sigmas:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = np.flip(timesteps).copy().astype(np.int64)
# when num_inference_steps == num_train_timesteps, we can end up with # when num_inference_steps == num_train_timesteps, we can end up with
# duplicates in timesteps. # duplicates in timesteps.
_, unique_indices = np.unique(timesteps, return_index=True) _, unique_indices = np.unique(timesteps, return_index=True)
...@@ -248,6 +260,44 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -248,6 +260,44 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
......
...@@ -206,7 +206,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -206,7 +206,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
) )
if self.use_karras_sigmas: if self.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
...@@ -241,14 +241,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -241,14 +241,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return t return t
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022).""" """Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item() sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item() sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, self.num_inference_steps) ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho) min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
......
...@@ -209,6 +209,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -209,6 +209,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.2251) < 1e-3 assert abs(result_mean.item() - 0.2251) < 1e-3
def test_full_loop_with_karras_and_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2096) < 1e-3
def test_switch(self): def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results # make sure that iterating over schedulers with same config names gives same results
# for defaults # for defaults
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment