"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f8d4a1e2833191328a6572890ae15bca793eef7c"
Unverified commit c812d97d, authored by Peter Lin and committed by GitHub
Browse files

Improve ddim scheduler and fix bug when prediction type is "sample" (#2094)



Improve ddim scheduler
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent c5f6c538
...@@ -301,12 +301,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -301,12 +301,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
...@@ -328,17 +329,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -328,17 +329,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
std_dev_t = eta * variance ** (0.5) std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output: if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide # the pred_epsilon is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0: if eta > 0:
device = model_output.device
if variance_noise is not None and generator is not None: if variance_noise is not None and generator is not None:
raise ValueError( raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or" "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
...@@ -347,7 +347,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -347,7 +347,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if variance_noise is None: if variance_noise is None:
variance_noise = randn_tensor( variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
) )
variance = std_dev_t * variance_noise variance = std_dev_t * variance_noise
......
...@@ -254,12 +254,13 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -254,12 +254,13 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
...@@ -272,7 +273,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -272,7 +273,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
std_dev_t = eta * variance ** (0.5) std_dev_t = eta * variance ** (0.5)
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment