"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "4150d0ac2f4e46572f0b16d51b2658700b0c2bdc"
Commit 9a1a6e97 authored by patil-suraj

rebase

parents f1823bbe 1122c707
@@ -45,28 +45,44 @@ image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.re
 # 3. Denoise
 for t in reversed(range(len(scheduler))):
-    # i) define coefficients for time step t
-    clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
-    clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
-    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
-    clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
-
-    # ii) predict noise residual
+    # 1. predict noise residual
     with torch.no_grad():
-        noise_residual = model(image, t)
+        pred_noise_t = model(image, t)
+
+    # 2. compute alphas, betas
+    alpha_prod_t = scheduler.get_alpha_prod(t)
+    alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev

-    # iii) compute predicted image from residual
-    # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
-    pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
-    pred_mean = torch.clamp(pred_mean, -1, 1)
-    prev_image = clipped_coeff * pred_mean + image_coeff * image
-
-    # iv) sample variance
-    prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
-
-    # v) sample  x_{t-1} ~ N(prev_image, prev_variance)
-    sampled_prev_image = prev_image + prev_variance
-    image = sampled_prev_image
+    # 3. compute predicted image from residual
+    # First: compute the predicted original image from the predicted noise, also called
+    # "predicted x_0" in formula (15) of https://arxiv.org/pdf/2006.11239.pdf
+    pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
+
+    # Second: clip "predicted x_0"
+    pred_original_image = torch.clamp(pred_original_image, -1, 1)
+
+    # Third: compute the coefficients of pred_original_image x_0 and of the current image x_t
+    # See formula (7) of https://arxiv.org/pdf/2006.11239.pdf
+    pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * scheduler.get_beta(t)) / beta_prod_t
+    current_image_coeff = scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
+
+    # Fourth: compute the predicted previous image µ_t
+    # See formula (7) of https://arxiv.org/pdf/2006.11239.pdf
+    pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
+
+    # 5. For t > 0, compute the predicted variance β̃_t (see formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf)
+    # and sample from it to get the previous image
+    # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
+    if t > 0:
+        variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * scheduler.get_beta(t)).sqrt()
+        noise = scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+        prev_image = pred_prev_image + variance * noise
+    else:
+        prev_image = pred_prev_image
+
+    # 6. Set the current image to prev_image: x_t -> x_{t-1}
+    image = prev_image

 # process image to PIL
 image_processed = image.cpu().permute(0, 2, 3, 1)
...
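As a quick sanity check on the refactor above (not part of the commit): with clipping disabled, the new "predicted x_0" route and the old coefficient-based update compute the same posterior mean, formula (7) of https://arxiv.org/pdf/2006.11239.pdf. The linear beta schedule and tensor shapes below are assumptions standing in for the real scheduler:

```python
import torch

torch.manual_seed(0)

# Assumed stand-ins for the scheduler: a linear beta schedule and toy tensors.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 500
image = torch.randn(1, 3, 8, 8)
pred_noise_t = torch.randn_like(image)

alpha_prod_t, alpha_prod_t_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
beta_prod_t, beta_prod_t_prev = 1 - alpha_prod_t, 1 - alpha_prod_t_prev

# New route: predict x_0 (formula (15)), then recombine with x_t (formula (7)).
pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
mu_new = (alpha_prod_t_prev.sqrt() * betas[t] / beta_prod_t) * pred_original_image + (
    alphas[t].sqrt() * beta_prod_t_prev / beta_prod_t
) * image

# Old route: the pre-refactor coefficient form.
clipped_image_coeff = 1 / alpha_prod_t.sqrt()
clipped_noise_coeff = (1 / alpha_prod_t - 1).sqrt()
image_coeff = beta_prod_t_prev * alphas[t].sqrt() / beta_prod_t
clipped_coeff = alpha_prod_t_prev.sqrt() * betas[t] / beta_prod_t
pred_mean = clipped_image_coeff * image - clipped_noise_coeff * pred_noise_t
mu_old = clipped_coeff * pred_mean + image_coeff * image

assert torch.allclose(mu_new, mu_old, atol=1e-5)
```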
@@ -42,7 +42,7 @@ class DDIM(DiffusionPipeline):
             generator=generator,
         )

-        # See formulas (9), (10) and (7) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read the DDIM paper to understand it in detail
         # Notation (<variable name> -> <name in paper>)
@@ -64,11 +64,10 @@ class DDIM(DiffusionPipeline):
             # 3. compute alphas, betas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
             alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-            beta_prod_t = (1 - alpha_prod_t)
-            beta_prod_t_prev = (1 - alpha_prod_t_prev)
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev

             # 4. Compute predicted previous image from predicted noise
             # First: compute the predicted original image from the predicted noise, also called
             # "predicted x_0" in formula (12) of https://arxiv.org/pdf/2010.02502.pdf
             pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
...
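For reference alongside the corrected formula numbers, here is a hedged sketch of the full DDIM update that formulas (12) and (16) of https://arxiv.org/pdf/2010.02502.pdf describe; `pred_noise_t` and the alpha products are assumed to come from the unet and scheduler as in the hunks above:

```python
import torch

def ddim_step(image, pred_noise_t, alpha_prod_t, alpha_prod_t_prev, eta=0.0, generator=None):
    beta_prod_t = 1 - alpha_prod_t
    # "predicted x_0" of formula (12)
    pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
    # sigma_t of formula (16): eta = 0 is deterministic DDIM, eta = 1 matches DDPM
    std_dev_t = eta * (
        (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    ).sqrt()
    # "direction pointing to x_t" of formula (12)
    pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
    prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
    if eta > 0:
        prev_image = prev_image + std_dev_t * torch.randn(image.shape, generator=generator)
    return prev_image
```

At eta = 0 the noise term vanishes and the update is deterministic, which is what lets DDIM take far fewer steps than DDPM.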
@@ -41,7 +41,7 @@ class DDPM(DiffusionPipeline):
         for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
             # 1. predict noise residual
             with torch.no_grad():
-                noise_residual = self.unet(image, t)
+                pred_noise_t = self.unet(image, t)

             # 2. compute alphas, betas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
@@ -50,24 +50,38 @@ class DDPM(DiffusionPipeline):
             beta_prod_t_prev = 1 - alpha_prod_t_prev

             # 3. compute predicted image from residual
-            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
-            # First: Compute inner formula
-            pred_mean = (1 / alpha_prod_t.sqrt()) * (image - beta_prod_t.sqrt() * noise_residual)
-            # Second: Clip
-            pred_mean = torch.clamp(pred_mean, -1, 1)
-            # Third: Compute outer coefficients
-            pred_mean_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
-            image_coeff = (beta_prod_t_prev * self.noise_scheduler.get_alpha(t).sqrt()) / beta_prod_t
-            # Fourth: Compute outer formula
-            prev_image = pred_mean_coeff * pred_mean + image_coeff * image
-
-            # 4. sample variance
-            prev_variance = self.noise_scheduler.sample_variance(
-                t, prev_image.shape, device=torch_device, generator=generator
-            )
-
-            # 5. sample x_{t-1} ~ N(prev_image, prev_variance) = add variance to predicted image
-            sampled_prev_image = prev_image + prev_variance
-            image = sampled_prev_image
+            # First: compute the predicted original image from the predicted noise, also called
+            # "predicted x_0" in formula (15) of https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
+
+            # Second: clip "predicted x_0"
+            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+
+            # Third: compute the coefficients of pred_original_image x_0 and of the current image x_t
+            # See formula (7) of https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
+            current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
+
+            # Fourth: compute the predicted previous image µ_t
+            # See formula (7) of https://arxiv.org/pdf/2006.11239.pdf
+            pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
+
+            # 5. For t > 0, compute the predicted variance β̃_t (see formulas (6) and (7) of https://arxiv.org/pdf/2006.11239.pdf)
+            # and sample from it to get the previous image
+            # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
+            if t > 0:
+                variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t)).sqrt()
+                # TODO(PVP):
+                # This variance seems to be good enough for inference - check whether `fixed_small` and
+                # `fixed_large` are really only needed for training or also for inference.
+                # Also note that LDM only uses "fixed_small";
+                # glide seems to use a weird mix of the two: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
+                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                sampled_variance = variance * noise
+                prev_image = pred_prev_image + sampled_variance
+            else:
+                prev_image = pred_prev_image
+
+            # 6. Set the current image to prev_image: x_t -> x_{t-1}
+            image = prev_image

         return image
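The TODO above contrasts two variance conventions. A minimal illustration of the difference, assuming a linear beta schedule in place of the scheduler's: "fixed_small" is the true posterior variance β̃_t of formula (7), while "fixed_large" is β_t itself; the two agree for large t and differ roughly by a factor of two near t = 0:

```python
import torch

# Assumed linear schedule standing in for the scheduler's; 1000 steps as in DDPM.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

# "fixed_small": the posterior variance beta-tilde_t of formula (7),
# which is what the `if t > 0` branch above takes the square root of.
variance_small = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# "fixed_large": beta_t itself, the upper bound Ho et al. use for x_0 ~ N(0, I).
variance_large = betas

print(variance_small[1] / variance_large[1])    # ~0.5 near the start
print(variance_small[-1] / variance_large[-1])  # ~1.0 at t = T - 1
```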
@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm

 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
-                    torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )

-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )

-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):
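A hedged sanity check (not part of the commit) that the reformatted `kl` above still matches PyTorch's analytic KL for a diagonal Gaussian against a standard normal; the shapes are assumptions:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8).clamp(-5, 5)
var, std = logvar.exp(), (0.5 * logvar).exp()

# Same reduction as kl(other=None): KL(N(mean, var) || N(0, 1)), summed over C, H, W.
manual = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=[1, 2, 3])
reference = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(dim=[1, 2, 3])
assert torch.allclose(manual, reference, atol=1e-4)
```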
@@ -872,7 +873,7 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get text embedding
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]
...
@@ -113,7 +113,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from `from_pretrained`
         if not os.path.isdir(pretrained_model_name_or_path):
...
@@ -62,6 +62,9 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

+        # TODO(PVP) - check how much of this is actually necessary!
+        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two:
+        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         if variance_type == "fixed_small":
             log_variance = torch.log(variance.clamp(min=1e-20))
         elif variance_type == "fixed_large":
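The clamp in the "fixed_small" branch is load-bearing: with alphas_cumprod_prev defined as 1 at the first step, variance[0] is exactly zero, so an unclamped log would yield -inf. A tiny illustration with made-up values:

```python
import torch

variance = torch.tensor([0.0, 1e-5, 2e-5])   # made-up values; variance[0] = 0 as at t = 0
print(torch.log(variance))                   # first entry is -inf
print(torch.log(variance.clamp(min=1e-20)))  # clamped: finite everywhere
```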
@@ -84,17 +87,6 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]

-    def sample_variance(self, time_step, shape, device, generator=None):
-        variance = self.log_variance[time_step]
-        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]
-        noise = self.sample_noise(shape, device=device, generator=generator)
-        sampled_variance = nonzero_mask * (0.5 * variance).exp()
-        sampled_variance = sampled_variance * noise
-        return sampled_variance
-
     def sample_noise(self, shape, device, generator=None):
         # always sample on CPU to be deterministic
         return torch.randn(shape, generator=generator).to(device)
...
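With `sample_variance` removed, its logic now lives inline in the pipelines (see the DDPM hunk above). A minimal sketch of the equivalent caller-side form, assuming the scheduler still exposes `log_variance` and `sample_noise` as shown in the hunks; the helper name is hypothetical:

```python
import torch

def add_variance_noise(scheduler, pred_prev_image, t, generator=None):
    # Mirrors the removed `nonzero_mask`: no noise is added at the final step t == 0.
    if t == 0:
        return pred_prev_image
    noise = scheduler.sample_noise(pred_prev_image.shape, device=pred_prev_image.device, generator=generator)
    std = (0.5 * scheduler.log_variance[t]).exp()  # exp(log_variance / 2) == standard deviation
    return pred_prev_image + std * noise
```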