Commit 7ac909d6 authored by patil-suraj

make ldm work, add classifier-free guidance

parent 9a1a6e97
@@ -2,10 +2,10 @@
import math
import numpy as np
import tqdm
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
@@ -740,30 +740,29 @@ class DiagonalGaussianDistribution(object):
    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )
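
    # For reference, the general case above is the closed-form KL divergence between
    # two diagonal Gaussians:
    #   KL(N(mu1, var1) || N(mu2, var2))
    #     = 0.5 * sum((mu1 - mu2)^2 / var2 + var1 / var2 - 1 - logvar1 + logvar2)
    # reduced over the channel and spatial dimensions [1, 2, 3].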
    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
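
    # nll above is the standard diagonal-Gaussian negative log-likelihood,
    #   -log p(x) = 0.5 * sum(log(2*pi) + logvar + (x - mean)^2 / var),
    # evaluated element-wise and reduced over dims.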
    def mode(self):
        return self.mean


class AutoencoderKL(ModelMixin, ConfigMixin):
    def __init__(
        self,
@@ -835,7 +834,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
            give_pre_end=give_pre_end,
        )
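        # the encoder outputs 2 * z_channels feature maps -- the mean and logvar of the
        # diagonal Gaussian posterior -- hence the factor of 2 on both sides of quant_conv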
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
@@ -864,7 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
        # eta corresponds to η in paper and should be between [0, 1]
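        # guidance_scale controls classifier-free guidance: 1.0 disables it,
        # larger values trade sample diversity for closer adherence to the prompt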
        if torch_device is None:
@@ -874,6 +873,10 @@
        self.vqvae.to(torch_device)
        self.bert.to(torch_device)

        if guidance_scale != 1.0:
            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
            uncond_embeddings = self.bert(uncond_input.input_ids)[0]
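            # uncond_embeddings is the embedding of the empty prompt; it serves as the
            # unconditional branch when the noise predictions are combined below
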
        # get text embedding
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
        text_embedding = self.bert(text_input.input_ids)[0]
@@ -886,45 +889,74 @@
            device=torch_device,
            generator=generator,
        )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the following
        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
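        # Putting formulas (12) and (16) together, each iteration below computes
        #   x_t-1 = sqrt(alpha_t-1) * pred_original_image
        #           + sqrt(1 - alpha_t-1 - sigma_t^2) * pred_noise_t
        #           + sigma_t * noise
        # where alpha denotes the cumulative product (alpha-bar) of the DDPM paper.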
        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # 1. predict noise residual
            if guidance_scale == 1.0:
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                context = text_embedding
                image_in = image
            else:
                image_in = torch.cat([image] * 2)
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                context = torch.cat([uncond_embeddings, text_embedding])

            with torch.no_grad():
                pred_noise_t = self.unet(image_in, timesteps, context=context)

            if guidance_scale != 1.0:
                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
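            # The two lines above implement classifier-free guidance (Ho & Salimans):
            # extrapolate from the unconditional prediction towards the text-conditional one,
            #   pred = pred_uncond + guidance_scale * (pred_text - pred_uncond),
            # so guidance_scale == 1.0 recovers the purely conditional prediction.
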
            # 2. get actual t and t-1
            train_step = inference_step_times[t]
            prev_train_step = inference_step_times[t - 1] if t > 0 else -1

            # 3. compute alphas, betas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
            # 4. Compute predicted previous image from predicted noise
            # First: compute predicted original image from predicted noise, also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

            # Second: Clip "predicted x_0"
            # pred_original_image = torch.clamp(pred_original_image, -1, 1)
            # Third: Compute variance "sigma_t(η)" -> see formula (16)
            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
            std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
            std_dev_t = eta * std_dev_t

            # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t

            # Fifth: Compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
            # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
            # Note: eta = 1.0 essentially corresponds to DDPM
            if eta > 0.0:
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                prev_image = pred_prev_image + std_dev_t * noise
            else:
                prev_image = pred_prev_image

            # 6. Set current image to prev_image: x_t -> x_t-1
            image = prev_image
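
        # scale latents back and decode with the first-stage model; 0.18215 is the
        # latent scaling factor used by the LDM autoencoder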
        image = 1 / 0.18215 * image
        image = self.vqvae.decode(image)
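
# A minimal usage sketch (the checkpoint id and from_pretrained loading below are
# illustrative assumptions, not part of this commit):
#
#   pipeline = LatentDiffusion.from_pretrained("CompVis/ldm-text2im-large-256")
#   image = pipeline("a painting of a squirrel eating a burger", guidance_scale=6.0, eta=0.3)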