".github/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "4a044c646663231905c5659d2f83d0abe60e0ed8"
Commit ca94e36c authored by patil-suraj's avatar patil-suraj
Browse files

fix LatentDiffusion

parent 76f0f1d4
...@@ -900,11 +900,12 @@ class LatentDiffusion(DiffusionPipeline): ...@@ -900,11 +900,12 @@ class LatentDiffusion(DiffusionPipeline):
num_trained_timesteps = self.noise_scheduler.timesteps num_trained_timesteps = self.noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn( image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
device=torch_device,
generator=generator, generator=generator,
) )
image = image.to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding # Ideally, read DDIM paper in-detail understanding
...@@ -937,46 +938,17 @@ class LatentDiffusion(DiffusionPipeline): ...@@ -937,46 +938,17 @@ class LatentDiffusion(DiffusionPipeline):
pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2) pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
# 2. get actual t and t-1 # 2. predict previous mean of image x_t-1
train_step = inference_step_times[t] pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
prev_train_step = inference_step_times[t - 1] if t > 0 else -1
# 3. optionally sample variance
# 3. compute alphas, betas variance = 0
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) if eta > 0:
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step) noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
beta_prod_t = 1 - alpha_prod_t variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 4. Compute predicted previous image from predicted noise
# First: compute predicted original image from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
# Second: Compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
std_dev_t = eta * std_dev_t
# Third: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
# Forth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
# 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
# Note: eta = 1.0 essentially corresponds to DDPM
if eta > 0.0:
noise = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
generator=generator,
)
noise = noise.to(torch_device)
prev_image = pred_prev_image + std_dev_t * noise
else:
prev_image = pred_prev_image
# 6. Set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = prev_image image = pred_prev_image + variance
# scale and decode image with vae # scale and decode image with vae
image = 1 / 0.18215 * image image = 1 / 0.18215 * image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment