"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "f3b7fb50e422b78fbbba2addcf6ea8f9e86f893a"
Unverified Commit 48207d66 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Scheduler] fix: EDM schedulers when using the exp sigma schedule. (#8385)

* fix: euledm when using the exp sigma schedule.

* fix-copies

* remove print.

* reduce friction

* yiyi's suggestioms
parent 2f6f426f
...@@ -243,13 +243,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -243,13 +243,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
ramp = np.linspace(0, 1, self.num_inference_steps) ramp = torch.linspace(0, 1, self.num_inference_steps)
if self.config.sigma_schedule == "karras": if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp) sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential": elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp) sigmas = self._compute_exponential_sigmas(ramp)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas) self.timesteps = self.precondition_noise(sigmas)
if self.config.final_sigmas_type == "sigma_min": if self.config.final_sigmas_type == "sigma_min":
...@@ -283,7 +283,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -283,7 +283,6 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
min_inv_rho = sigma_min ** (1 / rho) min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
......
...@@ -16,7 +16,6 @@ import math ...@@ -16,7 +16,6 @@ import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
...@@ -210,13 +209,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -210,13 +209,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
""" """
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
ramp = np.linspace(0, 1, self.num_inference_steps) ramp = torch.linspace(0, 1, self.num_inference_steps)
if self.config.sigma_schedule == "karras": if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp) sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential": elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp) sigmas = self._compute_exponential_sigmas(ramp)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas) self.timesteps = self.precondition_noise(sigmas)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
...@@ -234,7 +233,6 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -234,7 +233,6 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
min_inv_rho = sigma_min ** (1 / rho) min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas return sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment