Commit 999d3856 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make code cleaner

parent 27039cd3
...@@ -42,7 +42,7 @@ class DDIM(DiffusionPipeline): ...@@ -42,7 +42,7 @@ class DDIM(DiffusionPipeline):
generator=generator, generator=generator,
) )
# See formulas (9), (10) and (7) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding # Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper> # Notation (<variable name> -> <name in paper>
...@@ -68,7 +68,6 @@ class DDIM(DiffusionPipeline): ...@@ -68,7 +68,6 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev = (1 - alpha_prod_t_prev) beta_prod_t_prev = (1 - alpha_prod_t_prev)
# 4. Compute predicted previous image from predicted noise # 4. Compute predicted previous image from predicted noise
# First: compute predicted original image from predicted noise also called # First: compute predicted original image from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt() pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
......
...@@ -41,33 +41,45 @@ class DDPM(DiffusionPipeline): ...@@ -41,33 +41,45 @@ class DDPM(DiffusionPipeline):
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# 1. predict noise residual # 1. predict noise residual
with torch.no_grad(): with torch.no_grad():
noise_residual = self.unet(image, t) pred_noise_t = self.unet(image, t)
# 2. compute alphas, betas # 2. compute alphas, betas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(t) alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1) alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
beta_prod_t = 1 - alpha_prod_t beta_prod_t = (1 - alpha_prod_t)
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = (1 - alpha_prod_t_prev)
# 3. compute predicted image from residual # 3. compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison # First: compute predicted original image from predicted noise also called
# First: Compute inner formula # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
pred_mean = (1 / alpha_prod_t.sqrt()) * (image - beta_prod_t.sqrt() * noise_residual) pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
# Second: Clip
pred_mean = torch.clamp(pred_mean, -1, 1) # Second: Clip "predicted x_0"
# Third: Compute outer coefficients pred_original_image = torch.clamp(pred_original_image, -1, 1)
pred_mean_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
image_coeff = (beta_prod_t_prev * self.noise_scheduler.get_alpha(t).sqrt()) / beta_prod_t # Third: Compute coefficients for pred_original_image x_0 and current image x_t
# Fourth: Compute outer formula # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
prev_image = pred_mean_coeff * pred_mean + image_coeff * image pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
# 4. sample variance # Fourth: Compute predicted previous image µ_t
prev_variance = self.noise_scheduler.sample_variance( # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
t, prev_image.shape, device=torch_device, generator=generator pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
)
# 5. For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# 5. sample x_{t-1} ~ N(prev_image, prev_variance) = add variance to predicted image # and sample from it to get previous image
sampled_prev_image = prev_image + prev_variance # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
image = sampled_prev_image if t > 0:
variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t)).sqrt()
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
sampled_variance = variance * noise
# sampled_variance = self.noise_scheduler.sample_variance(
# t, pred_prev_image.shape, device=torch_device, generator=generator
# )
prev_image = pred_prev_image + sampled_variance
else:
prev_image = pred_prev_image
# 6. Set current image to prev_image: x_t -> x_t-1
image = prev_image
return image return image
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment