"...sox/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "e65e472698b8987398816ee83ea9dbcba1f713f3"
Unverified commit 11f527ac, authored by Youssef Adarrab, committed by GitHub
Browse files

Add `Karras sigmas` to HeunDiscreteScheduler (#3160)



* Add karras pattern to discrete heun scheduler

* Add integration test

* Fix failing CI on pytorch test on M1 (mps)

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2c04e585
......@@ -75,7 +75,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
https://imagen.research.google/video/paper.pdf).
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
......@@ -90,6 +94,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
......@@ -111,6 +116,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self.use_karras_sigmas = use_karras_sigmas
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
......@@ -165,7 +171,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
......@@ -186,6 +198,44 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.prev_derivative = None
self.dt = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
@property
def state_in_first_order(self) -> bool:
    """Whether no step size is currently stored, i.e. `self.dt is None`.

    NOTE(review): presumably `dt` is cleared whenever the scheduler is about to
    take the first (first-order) stage of a Heun step -- confirm against `step`.
    """
    return self.dt is None
......
......@@ -129,3 +129,28 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
# CUDA
assert abs(result_sum.item() - 0.1233) < 1e-2
assert abs(result_mean.item() - 0.0002) < 1e-3
def test_full_loop_device_karras_sigmas(self):
    """Run the full sampling loop on `torch_device` with `use_karras_sigmas=True`."""
    scheduler_cls = self.scheduler_classes[0]
    config = self.get_scheduler_config()
    scheduler = scheduler_cls(**config, use_karras_sigmas=True)
    scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

    model = self.dummy_model()
    # Scale the deterministic dummy sample to the scheduler's initial noise level.
    sample = (self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma).to(torch_device)

    for t in scheduler.timesteps:
        sample = scheduler.scale_model_input(sample, t)
        model_output = model(sample, t)
        sample = scheduler.step(model_output, t, sample).prev_sample

    magnitude = torch.abs(sample)
    result_sum = torch.sum(magnitude)
    result_mean = torch.mean(magnitude)

    # NOTE(review): both tolerances (1e-2) are orders of magnitude larger than the
    # expected values, so these checks only verify that the result is close to zero.
    assert abs(result_sum.item() - 0.00015) < 1e-2
    assert abs(result_mean.item() - 1.9869554535034695e-07) < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment