Commit 7ac909d6 authored by patil-suraj

make ldm work, add classifier free guidance

parent 9a1a6e97
@@ -2,10 +2,10 @@
 import math
 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,30 +740,29 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.0])
+            return torch.Tensor([0.])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var
-                    - 1.0
-                    - self.logvar
-                    + other.logvar,
-                    dim=[1, 2, 3],
-                )
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])

-    def nll(self, sample, dims=[1, 2, 3]):
+    def nll(self, sample, dims=[1,2,3]):
         if self.deterministic:
-            return torch.Tensor([0.0])
+            return torch.Tensor([0.])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
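Both sides of this hunk compute the same closed-form quantities; for reference, the diagonal-Gaussian KL term and negative log-likelihood the code implements (a transcription for readability, not part of the commit):

\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(\mu',\sigma'^2)\big)
    = \tfrac{1}{2}\sum\Big[\tfrac{(\mu-\mu')^2}{\sigma'^2} + \tfrac{\sigma^2}{\sigma'^2} - 1 - \log\sigma^2 + \log\sigma'^2\Big]

-\log p(x) = \tfrac{1}{2}\sum\Big[\log 2\pi + \log\sigma^2 + \tfrac{(x-\mu)^2}{\sigma^2}\Big]

With other=None the first expression reduces to \tfrac{1}{2}\sum\big[\mu^2 + \sigma^2 - 1 - \log\sigma^2\big], the KL against a standard normal, which is exactly the `if other is None` branch.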
@@ -835,7 +834,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
-        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):
@@ -864,7 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
         super().__init__()
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
+    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
@@ -873,7 +872,11 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_embeddings = self.bert(uncond_input.input_ids)[0]
+
         # get text embedding
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]
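The unconditional (empty-prompt) embeddings added here are one half of the classifier-free guidance pair: in the sampling loop below, the UNet is run on both contexts and the two noise predictions are blended. A minimal sketch of that blend, with random tensors standing in for the UNet outputs (the shapes and the guidance value are illustrative assumptions, not taken from the commit):

import torch

guidance_scale = 5.0  # assumed value; 1.0 recovers plain conditional sampling
pred_noise_uncond = torch.randn(1, 4, 32, 32)  # toy stand-in for e_theta(x_t, t, "")
pred_noise_cond = torch.randn(1, 4, 32, 32)    # toy stand-in for e_theta(x_t, t, prompt)

# Move the unconditional prediction toward the conditional one, scaled by
# guidance_scale -- the same chunk-and-recombine step the loop performs.
pred_noise = pred_noise_uncond + guidance_scale * (pred_noise_cond - pred_noise_uncond)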
@@ -886,46 +889,75 @@ class LatentDiffusion(DiffusionPipeline):
             device=torch_device,
             generator=generator,
         )

+        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper for an in-detail understanding.
+        # Notation (<variable name> -> <name in paper>):
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_image -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_image_direction -> "direction pointing to x_t"
+        # - pred_prev_image -> "x_t-1"
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # get actual t and t-1
+            # 1. predict noise residual
+            if guidance_scale == 1.0:
+                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
+                context = text_embedding
+                image_in = image
+            else:
+                image_in = torch.cat([image] * 2)
+                timesteps = torch.tensor([inference_step_times[t]] * image_in.shape[0], device=torch_device)
+                context = torch.cat([uncond_embeddings, text_embedding])
+
+            with torch.no_grad():
+                pred_noise_t = self.unet(image_in, timesteps, context=context)
+
+            if guidance_scale != 1.0:
+                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
+                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
+
+            # 2. get actual t and t-1
             train_step = inference_step_times[t]
             prev_train_step = inference_step_times[t - 1] if t > 0 else -1

-            # compute alphas
+            # 3. compute alphas, betas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
             alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
-            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
-            beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
-            beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
-
-            # compute relevant coefficients
-            coeff_1 = (
-                (alpha_prod_t_prev - alpha_prod_t).sqrt()
-                * alpha_prod_t_prev_rsqrt
-                * beta_prod_t_prev_sqrt
-                / beta_prod_t_sqrt
-                * eta
-            )
-            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
-
-            # model forward
-            with torch.no_grad():
-                train_step = torch.tensor([train_step] * image.shape[0], device=torch_device)
-                noise_residual = self.unet(image, train_step, context=text_embedding)
-
-            # predict mean of prev image
-            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
-            pred_mean = torch.clamp(pred_mean, -1, 1)
-            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
-
-            # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 4. Compute predicted previous image from predicted noise
+            # First: compute predicted original image from predicted noise, also called
+            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
+
+            # Second: Clip "predicted x_0"
+            # pred_original_image = torch.clamp(pred_original_image, -1, 1)
+
+            # Third: Compute variance: "sigma_t(η)" -> see formula (16)
+            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+            std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
+            std_dev_t = eta * std_dev_t
+
+            # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
+
+            # Fifth: Compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+
+            # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
+            # Note: eta = 1.0 essentially corresponds to DDPM
             if eta > 0.0:
                 noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-                image = pred_mean + coeff_1 * noise
+                prev_image = pred_prev_image + std_dev_t * noise
             else:
-                image = pred_mean
+                prev_image = pred_prev_image
+
+            # 6. Set current image to prev_image: x_t -> x_t-1
+            image = prev_image

         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)
         image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
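For reference, the two DDIM formulas the new comments cite, in the paper's notation where \alpha_t stands for the cumulative product \bar\alpha_t (a transcription, not part of the commit):

x_{t-1} = \sqrt{\alpha_{t-1}}\,\underbrace{\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\alpha_t}}\right)}_{\text{predicted } x_0}
          + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t,t)}_{\text{direction pointing to } x_t}
          + \sigma_t\,\epsilon_t \qquad \text{(12)}

\sigma_t(\eta) = \eta\,\sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}\,\sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}} \qquad \text{(16)}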
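And a self-contained sketch of a single update from the new sampling loop, with scalar alpha-bar values standing in for noise_scheduler.get_alpha_prod() (the numbers and tensor shapes are assumptions chosen for illustration):

import torch

eta = 0.0                              # 0.0 = deterministic DDIM, 1.0 ≈ DDPM
alpha_prod_t = torch.tensor(0.6)       # assumed alpha-bar at step t
alpha_prod_t_prev = torch.tensor(0.8)  # assumed alpha-bar at step t-1
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

image = torch.randn(1, 4, 32, 32)         # x_t (toy latent)
pred_noise_t = torch.randn(1, 4, 32, 32)  # e_theta(x_t, t) (toy prediction)

# "predicted x_0", formula (12)
pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
# sigma_t(eta), formula (16)
std_dev_t = eta * (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
# "direction pointing to x_t", formula (12)
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
# x_t-1 without the random-noise term
prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
if eta > 0.0:
    prev_image = prev_image + std_dev_t * torch.randn_like(image)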