Unverified Commit 464374fb authored by hlky, committed by GitHub

EDMEulerScheduler accept sigmas, add final_sigmas_type (#10734)

parent d43ce14e
@@ -14,7 +14,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
             Video](https://imagen.research.google/video/paper.pdf) paper).
         rho (`float`, *optional*, defaults to 7.0):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
     """
 
     _compatibles = []
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps = None
 
-        ramp = torch.linspace(0, 1, num_train_timesteps)
+        sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
         if sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         self.timesteps = self.precondition_noise(sigmas)
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
 
         self.is_scale_input_called = False
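The new `final_sigmas_type` value is an ordinary config argument, so it can be set at construction time (or through `from_config`) like any other scheduler option. A minimal usage sketch, not part of the commit itself, illustrating the two accepted values:

```python
from diffusers import EDMEulerScheduler

# Default behavior: the appended final sigma is 0, so sampling ends at a clean sample.
scheduler_zero = EDMEulerScheduler(sigma_schedule="karras", final_sigmas_type="zero")

# Alternative: the final sigma repeats the smallest sigma of the training schedule.
scheduler_min = EDMEulerScheduler(sigma_schedule="karras", final_sigmas_type="sigma_min")

print(scheduler_zero.sigmas[-1])  # tensor(0.)
print(scheduler_min.sigmas[-1])   # equals scheduler_min.sigmas[-2], i.e. sigma_min
```

Any other string raises the `ValueError` added in `__init__` above.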
@@ -197,7 +210,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -206,19 +224,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+                Custom sigmas to use for the denoising process. If not defined, the default behavior when
+                `num_inference_steps` is passed will be used.
         """
         self.num_inference_steps = num_inference_steps
 
-        ramp = torch.linspace(0, 1, self.num_inference_steps)
+        if sigmas is None:
+            sigmas = torch.linspace(0, 1, self.num_inference_steps)
+        elif isinstance(sigmas, float):
+            sigmas = torch.tensor(sigmas, dtype=torch.float32)
+        else:
+            sigmas = sigmas
         if self.config.sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif self.config.sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
 
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
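With this change, `set_timesteps` can be driven either by `num_inference_steps` (the default linear ramp) or by a custom `sigmas` argument. One subtlety visible in the diff: the passed values are still routed through `_compute_karras_sigmas` / `_compute_exponential_sigmas`, so they act as the schedule's interpolation ramp in [0, 1] rather than as raw sigma values. A sketch of both call patterns (illustration only, not from the commit):

```python
import torch

from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler(sigma_schedule="karras", final_sigmas_type="zero")

# Default path: a linear ramp with `num_inference_steps` points.
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas.shape)  # torch.Size([26]): 25 schedule sigmas + the final sigma

# Custom path: supply the ramp directly. Squaring the linear ramp clusters points
# near 0, which the Karras schedule maps to the high-noise end (ramp=0 -> sigma_max,
# ramp=1 -> sigma_min), so this spends more of the 10 steps early in denoising.
custom_ramp = torch.linspace(0, 1, 10) ** 2
scheduler.set_timesteps(sigmas=custom_ramp)
print(scheduler.sigmas)  # 10 schedule sigmas + a final sigma of 0
```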