"...composable_kernel.git" did not exist on "18d2bb1b981d7b851a9f8cef98d8fae576d8a2c5"
Unverified Commit c812d97d authored by Peter Lin's avatar Peter Lin Committed by GitHub
Browse files

Improve ddim scheduler and fix bug when prediction type is "sample" (#2094)



Improve ddim scheduler
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent c5f6c538
......@@ -301,12 +301,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
......@@ -328,17 +329,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
device = model_output.device
if variance_noise is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
......@@ -347,7 +347,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if variance_noise is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
variance = std_dev_t * variance_noise
......
......@@ -254,12 +254,13 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
......@@ -272,7 +273,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
std_dev_t = eta * variance ** (0.5)
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment