Unverified Commit 30f6f441 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

add v prediction (#1386)



* add v prediction

* adat euler for v pred

* velocity -> v_prediction

* simplify

* fix naming

* Update src/diffusers/schedulers/scheduling_euler_discrete.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* style
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent 9f476388
...@@ -122,6 +122,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -122,6 +122,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True, clip_sample: bool = True,
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
...@@ -138,6 +139,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -138,6 +139,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
...@@ -258,7 +261,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -258,7 +261,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called # 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
# 4. Clip "predicted x_0" # 4. Clip "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
......
...@@ -78,6 +78,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -78,6 +78,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
prediction_type: str = "epsilon",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas) self.betas = torch.from_numpy(trained_betas)
...@@ -91,6 +92,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -91,6 +92,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
...@@ -229,7 +232,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -229,7 +232,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output pred_original_sample = sample - sigma_hat * model_output
elif self.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative # 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat derivative = (sample - pred_original_sample) / sigma_hat
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment