Unverified Commit 0f1c2466 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

fix heun scheduler (#1512)

parent e65b71ab
...@@ -186,10 +186,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -186,10 +186,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output sigma_input = sigma_hat if self.state_in_first_order else sigma_next
pred_original_sample = sample - sigma_input * model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip sigma_input = sigma_hat if self.state_in_first_order else sigma_next
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1)
)
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
...@@ -207,7 +210,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -207,7 +210,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sample = sample self.sample = sample
else: else:
# 2. 2nd order / Heun's method # 2. 2nd order / Heun's method
derivative = (sample - pred_original_sample) / sigma_hat derivative = (sample - pred_original_sample) / sigma_next
derivative = (self.prev_derivative + derivative) / 2 derivative = (self.prev_derivative + derivative) / 2
# 3. Retrieve 1st order derivative # 3. Retrieve 1st order derivative
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment