Commit fe313730 authored by Patrick von Platen

improve

parent 3a5c65d5
...@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th

Example:

```python
from diffusers import UNetModel, GaussianDDPMScheduler
import torch

# 1. Load model
...@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
image = unet(dummy_noise, time_step)

# 3. Load sampler
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# 4. Sample image from sampler passing the model
image = sampler.sample(model, batch_size=1)
...@@ -54,12 +54,12 @@ print(image)
Example:

```python
from diffusers import UNetModel, GaussianDDPMScheduler
from modeling_ddpm import DDPM
import tempfile

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
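A minimal sketch of the save/load round-trip described above, illustrative only, assuming the dummy checkpoint from the examples and the `save_pretrained`/`save_config` methods that appear later in this commit:

```python
import tempfile

from diffusers import GaussianDDPMScheduler, UNetModel

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

with tempfile.TemporaryDirectory() as tmpdirname:
    # write model weights + config and the scheduler config to disk
    unet.save_pretrained(tmpdirname)
    scheduler.save_config(tmpdirname)

    # restore both from the same directory
    unet = UNetModel.from_pretrained(tmpdirname)
    scheduler = GaussianDDPMScheduler.from_config(tmpdirname)
```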
......
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDDPMScheduler
import torch
import torch.nn.functional as F
import numpy as np
import PIL.Image
import tqdm
#torch_device = "cuda"
#
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
#unet.to(torch_device)
#
#TIME_STEPS = 10
#
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
#diffusion_config = {
# "beta_start": 0.0001,
# "beta_end": 0.02,
# "num_diffusion_timesteps": TIME_STEPS,
#}
#
# 2. Do one denoising step with model
#batch_size, num_channels, height, width = 1, 3, 256, 256
#
#torch.manual_seed(0)
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
#
#
# Helper
#def noise_like(shape, device, repeat=False):
#    def repeat_noise():
#        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
#
#    def noise():
#        return torch.randn(shape, device=device)
#
#    return repeat_noise() if repeat else noise()
#
#
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
#betas = torch.tensor(betas, device=torch_device)
#alphas = 1.0 - betas
#
#alphas_cumprod = torch.cumprod(alphas, axis=0)
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
#coeff = 1 / torch.sqrt(alphas)
def real_fn():
    # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
    # 1: x_t ~ N(0, 1)
    x_t = noise_image
    # 2: for t = T, ..., 1 do
    for i in reversed(range(TIME_STEPS)):
        t = torch.tensor([i]).to(torch_device)
        # 3: z ~ N(0, 1)
        noise = noise_like(x_t.shape, torch_device)
        # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
        # ------------------------- MODEL ------------------------------------#
        with torch.no_grad():
            pred_noise = unet(x_t, t)  # pred epsilon_theta
        # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
        # pred_x.clamp_(-1.0, 1.0)
        # pred mean
        # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
        # --------------------------------------------------------------------#
        posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)
        # ------------------------- Variance Scheduler -----------------------#
        # pred variance
        posterior_log_variance = posterior_log_variance_clipped[t]
        b, *_, device = *x_t.shape, x_t.device
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
        posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
        # --------------------------------------------------------------------#
        x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
        x_t = x_t_1
        print(x_t.abs().sum())


def post_process_to_image(x_t):
    image = x_t.cpu().permute(0, 2, 3, 1)
    image = (image + 1.0) * 127.5
    image = image.numpy().astype(np.uint8)
    return PIL.Image.fromarray(image[0])
from pytorch_diffusion import Diffusion
#diffusion = Diffusion.from_pretrained("lsun_church")
#samples = diffusion.denoise(1)
#
#image = post_process_to_image(samples)
#image.save("check.png")
#import ipdb; ipdb.set_trace()
device = "cuda"
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
import ipdb; ipdb.set_trace()
model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)
torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)

for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
    # define coefficients for time step t
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

    # predict noise residual
    with torch.no_grad():
        noise_residual = model(next_image, t)

    # compute prev image from noise
    pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    image = clip_coeff * pred_mean + image_coeff * next_image

    # sample variance
    variance = scheduler.sample_variance(t, image.shape, device=device)

    # sample previous image
    sampled_image = image + variance
    next_image = sampled_image

image = post_process_to_image(next_image)
image.save("example_new.png")
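# For reference, the update that both real_fn above and the scheduler-based
# loop implement is step 4 of Algorithm 2 in the DDPM paper
# (https://arxiv.org/pdf/2006.11239.pdf):
#
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
#
# with z ~ N(0, I) for t > 0 and z = 0 at t = 0, which is what the
# nonzero_mask / scheduler.sample_variance logic enforces.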
#!/usr/bin/env python3
import tempfile

from diffusers import GaussianDDPMScheduler, UNetModel
from modeling_ddpm import DDPM

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
......
...@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline


class DDPM(DiffusionPipeline):
    def __init__(self, unet, gaussian_sampler):
        super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
......
#!/usr/bin/env python3
import torch

from diffusers import GaussianDDPMScheduler, UNetModel

model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps; L1 or L2
training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized to a range of -1 to +1
loss = diffusion(training_images)
......
...@@ -4,8 +4,7 @@

__version__ = "0.0.1"

from .modeling_utils import PreTrainedModel
from .models.unet import UNetModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
...@@ -17,10 +17,10 @@

import copy
import inspect
import json
import os
import re
from typing import Any, Dict, Tuple, Union

from requests import HTTPError
...@@ -186,6 +186,11 @@ class Config:

        expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
        expected_keys.remove("self")

        for key in expected_keys:
            if key in kwargs:
                # overwrite key
                config_dict[key] = kwargs.pop(key)

        passed_keys = set(config_dict.keys())

        unused_kwargs = kwargs
...@@ -194,17 +199,16 @@

        if len(expected_keys - passed_keys) > 0:
            logger.warn(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        return config_dict, unused_kwargs

    @classmethod
    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
        config_dict, unused_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
        )

        model = cls(**config_dict)
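        # Sketch of the key-overwrite behavior added above (illustrative, not
        # part of this file): kwargs passed to from_config now take precedence
        # over values stored in the config, e.g.
        #
        #   scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy", timesteps=10)
        #   assert len(scheduler) == 10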
......
...@@ -24,6 +24,7 @@ from requests import HTTPError

# CHANGE to diffusers.utils
from transformers.utils import (
    CONFIG_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    EntryNotFoundError,
    RepositoryNotFoundError,
...@@ -33,7 +34,6 @@ from transformers.utils import (
    is_offline_mode,
    is_remote_url,
    logging,
)
......
...@@ -17,376 +17,337 @@

import copy
import math
from pathlib import Path

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
from tqdm import tqdm

from ..configuration_utils import Config
from ..modeling_utils import PreTrainedModel
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
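# Shape illustration (sketch, not part of the diff): for a batch of integer
# timesteps,
#   get_timestep_embedding(torch.tensor([0, 10, 999]), embedding_dim=128)
# returns a float tensor of shape (3, 128): sin features in the first 64
# columns, cos features in the last 64.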
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
class UNetModel(PreTrainedModel, Config):
    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
    ):
        super().__init__()
        self.register(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=in_channels,
            resolution=resolution,
        )
        ch_mult = tuple(ch_mult)
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList(
            [
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ]
        )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
        assert x.shape[2] == x.shape[3] == self.resolution

        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
# dataset classes


class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(image_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)
...@@ -398,10 +359,38 @@ class Dataset(data.Dataset):
# trainer class


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Trainer(object):
    def __init__(
        self,
        diffusion_model,
......
...@@ -14,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
from typing import Optional, Union

from .configuration_utils import Config

# CHANGE to diffusers.utils
from transformers.utils import logging


INDEX_FILE = "diffusion_model.pt"
...@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)

LOADABLE_CLASSES = {
    "diffusers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "GaussianDDPMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from inspect import isfunction
from tqdm import tqdm
from ..configuration_utils import Config
SAMPLING_CONFIG_NAME = "sampler_config.json"
def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


# gaussian diffusion trainer class


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
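# Quick comparison of the two schedules (illustrative sketch, assuming the two
# helpers above):
#
#   betas_lin = linear_beta_schedule(1000)   # linear ramp from 1e-4 to 0.02
#   betas_cos = cosine_beta_schedule(1000)   # near 0 early, clipped at 0.999
#   alpha_bar = torch.cumprod(1.0 - betas_cos, axis=0)  # decays more slowly early in the chain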
class GaussianDiffusion(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        image_size,
        channels=3,
        timesteps=1000,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="cosine",
    ):
        super().__init__()
        self.register(
            image_size=image_size,
            channels=channels,
            timesteps=timesteps,
            loss_type=loss_type,
            objective=objective,
            beta_schedule=beta_schedule,
        )
        self.channels = channels
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer(
            "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised: bool):
        model_output = model(x, t)

        if self.objective == "pred_noise":
            x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
        elif self.objective == "pred_x0":
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
        if noise is None:
            noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return result

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(
            reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
        ):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, model, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
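    # q_sample draws from the closed-form forward process
    #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # which is what lets training noise an image to an arbitrary step t in one shot.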
    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")

    def p_losses(self, model, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = model(x, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)
        return loss

    def forward(self, model, img, *args, **kwargs):
        b, _, h, w, device, img_size, = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert h == img_size and w == img_size, f"height and width of image must be {img_size}"

        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(model, img, t, *args, **kwargs)
...@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .gaussian_ddpm import GaussianDDPMScheduler
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from ..configuration_utils import Config
SAMPLING_CONFIG_NAME = "scheduler_config.json"
def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


class GaussianDDPMScheduler(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
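# In the notation of the sampling loops used elsewhere in this commit, the
# quantities exposed above combine into the standard DDPM posterior over
# x_{t-1}:
#
#   x0_pred = 1/sqrt(alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps_theta    # clip_image_coeff, clip_noise_coeff
#   mean    = sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t) * x0_pred
#             + (1 - alpha_bar_{t-1}) * sqrt(alpha_t) / (1 - alpha_bar_t) * x_t  # clip_coeff, image_coeff
#   var     = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)                 # the "fixed_small" variance above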
...@@ -16,13 +16,45 @@

import random
import tempfile
import unittest
import os
from distutils.util import strtobool

import torch

from diffusers import GaussianDDPMScheduler, UNetModel


global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"


def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        _value = default
    else:
        # KEY is set, convert it to True or False.
        try:
            _value = strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
    """
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
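# Usage sketch (hypothetical test shown purely for illustration):
#
#   class MySlowTests(unittest.TestCase):
#       @slow
#       def test_end_to_end_sampling(self):
#           ...  # only runs when e.g. RUN_SLOW=yes is set in the environment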
def floats_tensor(shape, scale=1.0, rng=None, name=None):
...@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):
        return (noise, time_step)

    def test_from_pretrained_save_pretrained(self):
        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
...@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):


class SamplerTesterMixin(unittest.TestCase):
    @slow
    def test_sample(self):
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        # 1. Load models
        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        # Note: The better test is to simply check with the following lines of code that the image is sensible
        # import PIL
        # import numpy as np
        # image_processed = image.cpu().permute(0, 2, 3, 1)
        # image_processed = (image_processed + 1.0) * 127.5
        # image_processed = image_processed.numpy().astype(np.uint8)
        # image_pil = PIL.Image.fromarray(image_processed[0])
        # image_pil.save("test.png")

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3

    def test_sample_fast(self):
        # 1. Load models
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        torch.manual_seed(0)
        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3