Commit 97226d97 authored by Patrick von Platen

upload cleaner scripts

parent 8d649443
@@ -64,8 +64,8 @@ class DDIM(DiffusionPipeline):
         # 3. compute alphas, betas
         alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
         alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-        beta_prod_t = (1 - alpha_prod_t)
-        beta_prod_t_prev = (1 - alpha_prod_t_prev)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev

         # 4. Compute predicted previous image from predicted noise
         # First: compute predicted original image from predicted noise also called
...
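Note on the hunk above: `alpha_prod_t` is the cumulative product of the per-step alphas, and `beta_prod_t = 1 - alpha_prod_t` is its complement. The "predicted original image" mentioned in the trailing comment follows from the standard DDPM/DDIM identity; a minimal sketch (the helper name is hypothetical, not part of this commit):

```python
import torch

# Standard identity: x_0 = (x_t - sqrt(1 - alpha_prod_t) * eps_pred) / sqrt(alpha_prod_t)
# (hypothetical helper name, shown only to illustrate the step the comment describes)
def predict_original_image(image, noise_pred, alpha_prod_t):
    beta_prod_t = 1 - alpha_prod_t  # same complement the hunk computes
    return (image - torch.as_tensor(beta_prod_t).sqrt() * noise_pred) / torch.as_tensor(alpha_prod_t).sqrt()
```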
@@ -46,8 +46,8 @@ class DDPM(DiffusionPipeline):
         # 2. compute alphas, betas
         alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
         alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
-        beta_prod_t = (1 - alpha_prod_t)
-        beta_prod_t_prev = (1 - alpha_prod_t_prev)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev

         # 3. compute predicted image from residual
         # First: compute predicted original image from predicted noise also called
@@ -69,12 +69,14 @@ class DDPM(DiffusionPipeline):
         # and sample from it to get previous image
         # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
         if t > 0:
-            variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t)).sqrt()
+            variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
+            # TODO(PVP):
+            # This variance seems to be good enough for inference - check if those `fixed_small`, `fixed_large`
+            # are really only needed for training or also for inference
+            # Also note LDM only uses "fixed_small";
+            # glide seems to use a weird mix of the two: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
             noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
             sampled_variance = variance * noise
-            # sampled_variance = self.noise_scheduler.sample_variance(
-            #     t, pred_prev_image.shape, device=torch_device, generator=generator
-            # )
             prev_image = pred_prev_image + sampled_variance
         else:
             prev_image = pred_prev_image
...
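The branch above only adds noise for t > 0; at the final step x_0 is returned deterministically. For reference, the DDPM paper's "fixed_small" posterior takes the square root of the whole product (1 - abar_{t-1}) / (1 - abar_t) * beta_t, whereas the right-hand line applies `.sqrt()` to `get_beta(t)` alone. A hedged sketch of the textbook step, with hypothetical names:

```python
import torch

# Textbook DDPM sampling step (function and argument names are hypothetical,
# not this repository's API):
def sample_prev_image(pred_prev_image, alpha_prod_t, alpha_prod_t_prev, beta_t, t, generator=None):
    if t == 0:
        return pred_prev_image  # the final step is deterministic
    # "fixed_small" posterior variance: beta_tilde_t = (1 - abar_{t-1}) / (1 - abar_t) * beta_t
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * beta_t
    noise = torch.randn(pred_prev_image.shape, generator=generator)
    return pred_prev_image + torch.as_tensor(variance).sqrt() * noise
```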
@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm

 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )

-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
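The `other is None` branch above is the closed-form KL(q || N(0, I)) for a diagonal Gaussian q, summed over the non-batch dimensions. A quick cross-check against `torch.distributions`, independent of this commit:

```python
import torch
from torch.distributions import Normal, kl_divergence

# For q = N(mu, sigma^2) diagonal and p = N(0, 1):
# KL(q || p) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2),
# which is exactly the `other is None` branch.
mean, logvar = torch.randn(2, 4), torch.randn(2, 4)
var = logvar.exp()

closed_form = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=1)
reference = kl_divergence(Normal(mean, var.sqrt()), Normal(0.0, 1.0)).sum(dim=1)
print(torch.allclose(closed_form, reference, atol=1e-5))  # True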
@@ -836,7 +837,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):
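The encoder of an LDM-style `AutoencoderKL` emits mean and log-variance stacked along the channel axis, which is why `quant_conv` takes `2 * z_channels` inputs and produces `2 * embed_dim` projected moments. A shape-only sketch (the surrounding encoder is omitted):

```python
import torch

# [mean | logvar] stacked on the channel axis in, projected moments out,
# later split back into the two halves.
z_channels, embed_dim = 4, 4
quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)

encoder_out = torch.randn(1, 2 * z_channels, 32, 32)  # stand-in encoder output
moments = quant_conv(encoder_out)                     # (1, 2 * embed_dim, 32, 32)
mean, logvar = torch.chunk(moments, 2, dim=1)         # each (1, embed_dim, 32, 32)
```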
@@ -874,11 +875,11 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(**text_input)[0]

         num_trained_timesteps = self.noise_scheduler.num_timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
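The last line above subsamples the trained timesteps evenly. A worked example assuming 1000 trained timesteps and 50 inference steps (values for illustration only):

```python
# With 1000 trained timesteps and 50 inference steps, the pipeline
# visits every 20th timestep.
num_trained_timesteps = 1000  # assumed value for illustration
num_inference_steps = 50
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
print(list(inference_step_times)[:5], len(inference_step_times))  # [0, 20, 40, 60, 80] 50
```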
@@ -927,9 +928,9 @@ class LatentDiffusion(DiffusionPipeline):
             image = pred_mean + coeff_1 * noise
         else:
             image = pred_mean

         image = 1 / image
         image = self.vqvae(image)

-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         return image
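The clamp above maps the decoder output from [-1, 1] into [0, 1]. A common follow-up, not part of this commit, converts the tensor to a uint8 image:

```python
import torch

# Stand-in decoder output in [-1, 1], clamped to [0, 1] as in the hunk,
# then scaled to bytes and laid out NHWC for numpy/PIL.
image = torch.rand(1, 3, 64, 64) * 2 - 1
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
image_uint8 = (image * 255).round().to(torch.uint8).permute(0, 2, 3, 1)
```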
@@ -113,7 +113,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
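The comment above refers to resolving `pretrained_model_name_or_path` as a Hub repo when it is not a local directory. A hedged sketch using `huggingface_hub.snapshot_download` (the repo id is hypothetical; keyword names mirror the kwargs popped in the hunk):

```python
import os
from huggingface_hub import snapshot_download

pretrained_model_name_or_path = "fusing/ddpm-cifar10"  # hypothetical repo id
if not os.path.isdir(pretrained_model_name_or_path):
    # not a local folder: download a full snapshot of the repo
    cached_folder = snapshot_download(
        pretrained_model_name_or_path,
        cache_dir=None,
        proxies=None,
        local_files_only=False,
    )
else:
    cached_folder = pretrained_model_name_or_path
```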
...
@@ -62,6 +62,9 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
         variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

+        # TODO(PVP) - check how much of this is actually necessary!
+        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
+        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         if variance_type == "fixed_small":
             log_variance = torch.log(variance.clamp(min=1e-20))
         elif variance_type == "fixed_large":
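For context, the two fixed variance types the branch above selects between come from the DDPM paper: "fixed_small" uses the posterior variance beta_tilde_t computed on the first line of the hunk, while "fixed_large" uses beta_t itself. A self-contained sketch with an assumed linear beta schedule:

```python
import torch

# Assumed linear beta schedule for illustration.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

# "fixed_small": posterior variance beta_tilde_t (zero at t=0, hence the clamp).
variance_small = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
log_variance_small = torch.log(variance_small.clamp(min=1e-20))

# "fixed_large": beta_t itself.
variance_large = betas
```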
@@ -84,17 +87,6 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]

-    def sample_variance(self, time_step, shape, device, generator=None):
-        variance = self.log_variance[time_step]
-        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]
-        noise = self.sample_noise(shape, device=device, generator=generator)
-        sampled_variance = nonzero_mask * (0.5 * variance).exp()
-        sampled_variance = sampled_variance * noise
-        return sampled_variance
-
     def sample_noise(self, shape, device, generator=None):
         # always sample on CPU to be deterministic
         return torch.randn(shape, generator=generator).to(device)
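Sampling on CPU with a seeded generator and then moving the result to the target device makes the noise identical across CPU-only and GPU runs. A usage sketch:

```python
import torch

# A seeded CPU generator yields the same noise no matter where the
# tensor is moved afterwards, so results are reproducible across devices.
generator = torch.Generator().manual_seed(0)
noise_a = torch.randn((1, 3, 8, 8), generator=generator)

generator = torch.Generator().manual_seed(0)
noise_b = torch.randn((1, 3, 8, 8), generator=generator).to("cpu")  # .to(device) in general
print(torch.allclose(noise_a, noise_b))  # True
```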
...