Unverified Commit b6447fa8 authored by Eyal Mazuz's avatar Eyal Mazuz Committed by GitHub
Browse files

Allow DDPM scheduler to use model's predicated variance (#132)



* Extented the ability of ddpm scheduler
to utilize model that also predict the variance.

* Update src/diffusers/schedulers/scheduling_ddpm.py
Co-authored-by: default avatarAnton Lozhkov <aglozhkov@gmail.com>
parent b6cadcef
......@@ -82,6 +82,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.variance_type = variance_type
def set_timesteps(self, num_inference_steps):
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps
......@@ -90,7 +92,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)[::-1].copy()
self.set_format(tensor_format=self.tensor_format)
def _get_variance(self, t, variance_type=None):
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
......@@ -113,6 +115,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif variance_type == "fixed_large_log":
# Glide max_log
variance = self.log(self.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = variance
max_log = self.betas[t]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
return variance
......@@ -125,6 +134,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
generator=None,
):
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
......@@ -155,7 +170,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance = 0
if t > 0:
noise = self.randn_like(model_output, generator=generator)
variance = (self._get_variance(t) ** 0.5) * noise
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment