Unverified Commit aab6de22 authored by dg845, committed by GitHub

Improve LCMScheduler (#5681)

* Refactor LCMScheduler.step such that prev_sample == denoised at the last timestep in the schedule.

* Make timestep scaling when calculating boundary conditions configurable.

* Reparameterize timestep_scaling as a multiplicative factor rather than a divisor.

* make style

* fix dtype conversion

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 1dc231d1
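
Note on the reparameterization: at the default timestep_scaling = 10.0, the new multiplicative form is numerically equivalent to the old hard-coded division by 0.1, since t * 10.0 == t / 0.1. A minimal standalone sketch of both parameterizations, with sigma_data fixed at 0.5 as in the scheduler (the helper names are illustrative, not part of this commit):

sigma_data = 0.5  # fixed default inside LCMScheduler

def scalings_old(t):
    # old form: divide the timestep by the hard-coded 0.1
    c_skip = sigma_data**2 / ((t / 0.1) ** 2 + sigma_data**2)
    c_out = (t / 0.1) / ((t / 0.1) ** 2 + sigma_data**2) ** 0.5
    return c_skip, c_out

def scalings_new(t, timestep_scaling=10.0):
    # new form: multiply the timestep by a configurable factor
    scaled_t = t * timestep_scaling
    c_skip = sigma_data**2 / (scaled_t**2 + sigma_data**2)
    c_out = scaled_t / (scaled_t**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

for t in (0.0, 99.0, 499.0, 999.0):
    old, new = scalings_old(t), scalings_new(t)
    assert abs(old[0] - new[0]) < 1e-12 and abs(old[1] - new[1]) < 1e-12

At t = 0 both forms give (c_skip, c_out) = (1, 0), which is exactly the consistency model boundary condition f(x, 0) = x.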
@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         timestep_spacing (`str`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        timestep_scaling (`float`, defaults to 10.0):
+            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+            error at the default of `10.0` is already pretty small).
         rescale_betas_zero_snr (`bool`, defaults to `False`):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -208,6 +212,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
         timestep_spacing: str = "leading",
+        timestep_scaling: float = 10.0,
         rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
@@ -380,12 +385,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = None
 
-    def get_scalings_for_boundary_condition_discrete(self, t):
+    def get_scalings_for_boundary_condition_discrete(self, timestep):
         self.sigma_data = 0.5  # Default: 0.5
+        scaled_timestep = timestep * self.config.timestep_scaling
 
-        # By dividing 0.1: This is almost a delta function at t=0.
-        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
-        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+        c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
         return c_skip, c_out
 
     def step(
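
The docstring's approximation-error claim follows directly from these formulas: once scaled_timestep is much larger than sigma_data, c_skip ≈ sigma_data**2 / scaled_timestep**2 and c_out ≈ 1, so a larger timestep_scaling pushes c_skip toward 0 at every nonzero timestep while leaving the exact t = 0 boundary condition (c_skip, c_out) = (1, 0) untouched. A rough numerical illustration at a small nonzero timestep, reusing the illustrative scalings_new helper sketched above:

for scaling in (1.0, 10.0, 100.0):
    c_skip, c_out = scalings_new(20.0, timestep_scaling=scaling)
    print(f"scaling={scaling:6.1f}  c_skip={c_skip:.2e}  c_out={c_out:.6f}")

Each 10x increase in the factor shrinks c_skip at t = 20 by roughly 100x; in other words the scalings approach a delta function at t = 0, which is what the removed "By dividing 0.1" comment was getting at.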
@@ -466,9 +471,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         denoised = c_out * predicted_original_sample + c_skip * sample
 
         # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
-        # Noise is not used for one-step sampling.
-        if len(self.timesteps) > 1:
-            noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
+        # Noise is not used on the final timestep of the timestep schedule.
+        # This also means that noise is not used for one-step sampling.
+        if self.step_index != self.num_inference_steps - 1:
+            noise = randn_tensor(
+                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+            )
             prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
         else:
             prev_sample = denoised
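
With this refactor, step() skips the noise injection on the final entry of the timestep schedule for any number of inference steps, so prev_sample == denoised there (previously this held only in the one-step case). A minimal usage sketch against the public scheduler API; the random tensors are stand-ins for a real UNet noise prediction and latents, and the shapes are arbitrary:

import torch
from diffusers import LCMScheduler

scheduler = LCMScheduler(num_train_timesteps=1000, timestep_scaling=10.0)
scheduler.set_timesteps(num_inference_steps=4)

sample = torch.randn(1, 4, 8, 8)  # stand-in for initial noise latents
generator = torch.Generator().manual_seed(0)
for t in scheduler.timesteps:
    model_output = torch.randn(1, 4, 8, 8)  # stand-in for a UNet epsilon prediction
    out = scheduler.step(model_output, t, sample, generator=generator)
    sample = out.prev_sample

# On the last timestep no noise is injected, so the returned sample is the
# denoised prediction itself.
assert torch.equal(out.prev_sample, out.denoised)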
@@ -230,7 +230,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
         result_mean = torch.mean(torch.abs(sample))
 
         # TODO: get expected sum and mean
-        assert abs(result_sum.item() - 18.7097) < 1e-2
+        assert abs(result_sum.item() - 18.7097) < 1e-3
         assert abs(result_mean.item() - 0.0244) < 1e-3
 
     def test_full_loop_multistep(self):
@@ -240,5 +240,5 @@ class LCMSchedulerTest(SchedulerCommonTest):
         result_mean = torch.mean(torch.abs(sample))
 
         # TODO: get expected sum and mean
-        assert abs(result_sum.item() - 280.5618) < 1e-2
-        assert abs(result_mean.item() - 0.3653) < 1e-3
+        assert abs(result_sum.item() - 197.7616) < 1e-3
+        assert abs(result_mean.item() - 0.2575) < 1e-3