Commit 2852c805 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

update readme

parent 97226d97
...@@ -45,28 +45,44 @@ image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.re ...@@ -45,28 +45,44 @@ image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.re
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # 1. predict noise residual
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) with torch.no_grad():
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) pred_noise_t = self.unet(image, t)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) # 2. compute alphas, betas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
# ii) predict noise residual alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
with torch.no_grad(): beta_prod_t = 1 - alpha_prod_t
noise_residual = model(image, t) beta_prod_t_prev = 1 - alpha_prod_t_prev
# iii) compute predicted image from residual # 3. compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison # First: compute predicted original image from predicted noise also called
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_mean = torch.clamp(pred_mean, -1, 1) pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
prev_image = clipped_coeff * pred_mean + image_coeff * image
# Second: Clip "predicted x_0"
# iv) sample variance pred_original_image = torch.clamp(pred_original_image, -1, 1)
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
# Third: Compute coefficients for pred_original_image x_0 and current image x_t
# v) sample x_{t-1} ~ N(prev_image, prev_variance) # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
sampled_prev_image = prev_image + prev_variance pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
image = sampled_prev_image current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
# Fourth: Compute predicted previous image µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
# 5. For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous image
# x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
if t > 0:
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
prev_image = pred_prev_image + variance * noise
else:
prev_image = pred_prev_image
# 6. Set current image to prev_image: x_t -> x_t-1
image = prev_image
# process image to PIL # process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1) image_processed = image.cpu().permute(0, 2, 3, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment