Unverified Commit 48207d66 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Scheduler] fix: EDM schedulers when using the exp sigma schedule. (#8385)

* fix: euledm when using the exp sigma schedule.

* fix-copies

* remove print.

* reduce friction

* yiyi's suggestioms
parent 2f6f426f
......@@ -243,13 +243,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
ramp = np.linspace(0, 1, self.num_inference_steps)
ramp = torch.linspace(0, 1, self.num_inference_steps)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
if self.config.final_sigmas_type == "sigma_min":
......@@ -283,7 +283,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
......
......@@ -16,7 +16,6 @@ import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
......@@ -210,13 +209,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
ramp = np.linspace(0, 1, self.num_inference_steps)
ramp = torch.linspace(0, 1, self.num_inference_steps)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
......@@ -234,7 +233,6 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment