Unverified Commit 1122c707 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Update README.md

parent 2852c805
...@@ -50,8 +50,8 @@ for t in reversed(range(len(scheduler))): ...@@ -50,8 +50,8 @@ for t in reversed(range(len(scheduler))):
pred_noise_t = self.unet(image, t) pred_noise_t = self.unet(image, t)
# 2. compute alphas, betas # 2. compute alphas, betas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(t) alpha_prod_t = scheduler.get_alpha_prod(t)
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1) alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = 1 - alpha_prod_t_prev
...@@ -65,8 +65,8 @@ for t in reversed(range(len(scheduler))): ...@@ -65,8 +65,8 @@ for t in reversed(range(len(scheduler))):
# Third: Compute coefficients for pred_original_image x_0 and current image x_t # Third: Compute coefficients for pred_original_image x_0 and current image x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * scheduler.get_beta(t)) / beta_prod_t
current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t current_image_coeff = scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
# Fourth: Compute predicted previous image µ_t # Fourth: Compute predicted previous image µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
...@@ -76,7 +76,7 @@ for t in reversed(range(len(scheduler))): ...@@ -76,7 +76,7 @@ for t in reversed(range(len(scheduler))):
# x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
if t > 0: if t > 0:
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt() variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) noise = scheduler.sample_noise(image.shape, device=image.device, generator=generator)
prev_image = pred_prev_image + variance * noise prev_image = pred_prev_image + variance * noise
else: else:
prev_image = pred_prev_image prev_image = pred_prev_image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment