Unverified Commit b6447fa8 authored by Eyal Mazuz, committed by GitHub

Allow DDPM scheduler to use model's predicted variance (#132)



* Extended the DDPM scheduler to utilize models that also predict the variance.

* Update src/diffusers/schedulers/scheduling_ddpm.py
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
parent b6cadcef
```diff
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -82,6 +82,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
+        self.variance_type = variance_type
+
     def set_timesteps(self, num_inference_steps):
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
@@ -90,7 +92,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         )[::-1].copy()
         self.set_format(tensor_format=self.tensor_format)
 
-    def _get_variance(self, t, variance_type=None):
+    def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
```
```diff
@@ -113,6 +115,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         elif variance_type == "fixed_large_log":
             # Glide max_log
             variance = self.log(self.betas[t])
+        elif variance_type == "learned":
+            return predicted_variance
+        elif variance_type == "learned_range":
+            min_log = variance
+            max_log = self.betas[t]
+            frac = (predicted_variance + 1) / 2
+            variance = frac * max_log + (1 - frac) * min_log
 
         return variance
```
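In the `learned_range` branch, the model emits a per-element value in [-1, 1] that interpolates between a lower variance bound (the posterior variance, i.e. `fixed_small`) and an upper bound (`beta_t`, i.e. `fixed_large`), following the Improved DDPM parameterization (Nichol & Dhariwal, 2021). A minimal sketch of that interpolation in the paper's log-space form; the helper name and arguments are illustrative, not diffusers API (note that the branch above interpolates the raw quantities despite the `min_log`/`max_log` names):

```python
import torch

def learned_range_variance(predicted_variance: torch.Tensor,
                           posterior_variance_t: torch.Tensor,
                           beta_t: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: interpolate between the posterior variance
    (lower bound) and beta_t (upper bound), driven by the model's
    prediction in [-1, 1]."""
    min_log = torch.log(posterior_variance_t)  # log of the "fixed_small" variance
    max_log = torch.log(beta_t)                # log of the "fixed_large" variance
    frac = (predicted_variance + 1) / 2        # map [-1, 1] -> [0, 1]
    log_variance = frac * max_log + (1 - frac) * min_log
    return torch.exp(log_variance)
```

Interpolating in log-space keeps the result strictly positive for any `frac` in [0, 1].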
```diff
@@ -125,6 +134,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         generator=None,
     ):
         t = timestep
+
+        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+        else:
+            predicted_variance = None
+
         # 1. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
```
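A model trained to predict the variance doubles its output channels: the first half is the usual noise prediction and the second half the variance prediction, so `step` can recover both with `torch.split`. A self-contained sketch of that split, with illustrative shapes:

```python
import torch

batch, channels, height, width = 2, 3, 32, 32  # illustrative shapes
sample = torch.randn(batch, channels, height, width)

# A variance-predicting model emits 2 * channels: [noise | variance].
model_output = torch.randn(batch, channels * 2, height, width)

if model_output.shape[1] == sample.shape[1] * 2:
    model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
    predicted_variance = None

assert model_output.shape == sample.shape        # noise-prediction half
assert predicted_variance.shape == sample.shape  # variance-prediction half
```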
```diff
@@ -155,7 +170,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance = 0
         if t > 0:
             noise = self.randn_like(model_output, generator=generator)
-            variance = (self._get_variance(t) ** 0.5) * noise
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
 
         pred_prev_sample = pred_prev_sample + variance
```
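With both pieces in place, opting into the learned-variance path only requires constructing the scheduler with a matching `variance_type`; `step` detects the doubled channel count on its own. A minimal sketch (how the `step` output is consumed depends on the diffusers version, so it is omitted here):

```python
from diffusers import DDPMScheduler

# "learned" returns the model's variance directly;
# "learned_range" interpolates between the small/large bounds.
scheduler = DDPMScheduler(variance_type="learned_range")

# A model_output with twice the sample's channels triggers the split in
# step(); with matching channels, predicted_variance stays None and the
# previous fixed-variance behavior is unchanged.
```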