Commit fe313730 authored by Patrick von Platen

improve

parent 3a5c65d5
......@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th
Example:
```python
from diffusers import UNetModel, GaussianDiffusion
from diffusers import UNetModel, GaussianDDPMScheduler
import torch
# 1. Load model
......@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
image = unet(dummy_noise, time_step)
# 3. Load sampler
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
# 4. Sample image from sampler passing the model
image = sampler.sample(model, batch_size=1)
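# 5. (Sketch, not part of this commit) Save and reload both objects locally,
# using the save_pretrained / save_config APIs exercised by the tests in this commit
import tempfile

with tempfile.TemporaryDirectory() as tmpdirname:
    unet.save_pretrained(tmpdirname)
    sampler.save_config(tmpdirname)

    unet_reloaded = UNetModel.from_pretrained(tmpdirname)
    sampler_reloaded = GaussianDDPMScheduler.from_config(tmpdirname)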
......@@ -54,12 +54,12 @@ print(image)
Example:
```python
from diffusers import UNetModel, GaussianDiffusion
from diffusers import UNetModel, GaussianDDPMScheduler
from modeling_ddpm import DDPM
import tempfile
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
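# (Sketch, not part of this commit) The tempfile import above suggests the
# example continues with a save/reload round-trip; assuming DiffusionPipeline
# routes save_pretrained / from_pretrained through the LOADABLE_CLASSES
# mapping defined in pipeline_utils below:
with tempfile.TemporaryDirectory() as tmpdirname:
    ddpm.save_pretrained(tmpdirname)
    ddpm_reloaded = DDPM.from_pretrained(tmpdirname)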
......
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDiffusion
from diffusers import UNetModel, GaussianDDPMScheduler
import torch
import torch.nn.functional as F
import numpy as np
import PIL.Image
import tqdm
#torch_device = "cuda"
#
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
#unet.to(torch_device)
#
#TIME_STEPS = 10
#
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
#diffusion_config = {
# "beta_start": 0.0001,
# "beta_end": 0.02,
# "num_diffusion_timesteps": TIME_STEPS,
#}
#
# 2. Do one denoising step with model
#batch_size, num_channels, height, width = 1, 3, 256, 256
#
#torch.manual_seed(0)
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
#
#
# Helper
#def noise_like(shape, device, repeat=False):
# def repeat_noise():
# return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
#
# def noise():
# return torch.randn(shape, device=device)
#
# return repeat_noise() if repeat else noise()
#
#
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
#betas = torch.tensor(betas, device=torch_device)
#alphas = 1.0 - betas
#
#alphas_cumprod = torch.cumprod(alphas, axis=0)
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
#coeff = 1 / torch.sqrt(alphas)
def real_fn():
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t = noise_image
# 2: for t = T, ...., 1 do
for i in reversed(range(TIME_STEPS)):
t = torch.tensor([i]).to(torch_device)
# 3: z ~ N(0, 1)
noise = noise_like(x_t.shape, torch_device)
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")
# 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
# ------------------------- MODEL ------------------------------------#
with torch.no_grad():
pred_noise = unet(x_t, t) # pred epsilon_theta
# 2. Do one denoising step with model
batch_size, num_channels, height, width = 1, 3, 32, 32
dummy_noise = torch.ones((batch_size, num_channels, height, width))
# pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
# pred_x.clamp_(-1.0, 1.0)
# pred mean
# posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
# --------------------------------------------------------------------#
posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)
TIME_STEPS = 10
# ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance = posterior_log_variance_clipped[t]
b, *_, device = *x_t.shape, x_t.device
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
# --------------------------------------------------------------------#
# Helper
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
x_t = x_t_1
print(x_t.abs().sum())
def noise_like(shape, device, repeat=False):
def repeat_noise():
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
def post_process_to_image(x_t):
image = x_t.cpu().permute(0, 2, 3, 1)
image = (image + 1.0) * 127.5
image = image.numpy().astype(np.uint8)
return repeat_noise() if repeat else noise()
return PIL.Image.fromarray(image[0])
# Schedule
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
from pytorch_diffusion import Diffusion
#diffusion = Diffusion.from_pretrained("lsun_church")
#samples = diffusion.denoise(1)
#
#image = post_process_to_image(samples)
#image.save("check.png")
#import ipdb; ipdb.set_trace()
betas = cosine_beta_schedule(TIME_STEPS)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
device = "cuda"
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
import ipdb; ipdb.set_trace()
model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t = dummy_noise
# 2: for t = T, ...., 1 do
for i in reversed(range(TIME_STEPS)):
t = torch.tensor([i])
# 3: z ~ N(0, 1)
noise = noise_like(x_t.shape, "cpu")
for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
# define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
# 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
# ------------------------- MODEL ------------------------------------#
pred_noise = unet(x_t, t) # pred epsilon_theta
pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
pred_x.clamp_(-1.0, 1.0)
# pred mean
posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
# --------------------------------------------------------------------#
# predict noise residual
with torch.no_grad():
noise_residual = model(next_image, t)
# ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
b, *_, device = *x_t.shape, x_t.device
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
# --------------------------------------------------------------------#
# compute prev image from noise
pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
image = clip_coeff * pred_mean + image_coeff * next_image
x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
# sample variance
variance = scheduler.sample_variance(t, image.shape, device=device)
# FOR PATRICK TO VERIFY: make sure manual loop is equal to function
# --------------------------------------------------------------------#
x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
# --------------------------------------------------------------------#
# sample previous image
sampled_image = image + variance
x_t = x_t_1
next_image = sampled_image
image = post_process_to_image(next_image)
image.save("example_new.png")
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDiffusion
from modeling_ddpm import DDPM
import tempfile
from diffusers import GaussianDDPMScheduler, UNetModel
from modeling_ddpm import DDPM
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
......
......@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline):
def __init__(self, unet, gaussian_sampler):
super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
......
#!/usr/bin/env python3
import torch
from diffusers import GaussianDiffusion, UNetModel
from diffusers import GaussianDDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)
......
......@@ -4,8 +4,7 @@
__version__ = "0.0.1"
from .modeling_utils import PreTrainedModel
from .models.unet import UNetModel
from .samplers.gaussian import GaussianDiffusion
from .pipeline_utils import DiffusionPipeline
from .modeling_utils import PreTrainedModel
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
......@@ -17,10 +17,10 @@
import copy
import inspect
import json
import os
import re
import inspect
from typing import Any, Dict, Tuple, Union
from requests import HTTPError
......@@ -186,6 +186,11 @@ class Config:
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
for key in expected_keys:
if key in kwargs:
# overwrite key
config_dict[key] = kwargs.pop(key)
passed_keys = set(config_dict.keys())
unused_kwargs = kwargs
......@@ -194,17 +199,16 @@ class Config:
if len(expected_keys - passed_keys) > 0:
logger.warn(
f"{expected_keys - passed_keys} was not found in config. "
f"Values will be initialized to default values."
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
return config_dict, unused_kwargs
@classmethod
def from_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
):
config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict, unused_kwargs = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
model = cls(**config_dict)
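# Example (as used elsewhere in this commit): kwargs that match __init__
# parameters override the values stored in the config before instantiation, e.g.
#   scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
#   assert len(scheduler) == 10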
......
......@@ -24,6 +24,7 @@ from requests import HTTPError
# CHANGE to diffusers.utils
from transformers.utils import (
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
......@@ -33,7 +34,6 @@ from transformers.utils import (
is_offline_mode,
is_remote_url,
logging,
CONFIG_NAME,
)
......
......@@ -14,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import Optional, Union
import importlib
from .configuration_utils import Config
# CHANGE to diffusers.utils
from transformers.utils import logging
from .configuration_utils import Config
INDEX_FILE = "diffusion_model.pt"
......@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"GaussianDiffusion": ["save_config", "from_config"],
"GaussianDDPMScheduler": ["save_config", "from_config"],
},
"transformers": {
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from inspect import isfunction
from tqdm import tqdm
from ..configuration_utils import Config
SAMPLING_CONFIG_NAME = "sampler_config.json"
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def cycle(dl):
while True:
for data_dl in dl:
yield data_dl
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
def normalize_to_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_to_zero_to_one(t):
return (t + 1) * 0.5
# small helper modules
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
# gaussian diffusion trainer class
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
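# e.g. with a of shape (T,) and t of shape (b,), extract(a, t, x.shape) returns
# a[t] reshaped to (b, 1, ..., 1) so it broadcasts over a batch of images.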
def noise_like(shape, device, repeat=False):
def repeat_noise():
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
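# Note: scaling the betas by 1000/timesteps keeps the total noise injected over
# the whole chain roughly constant when running with fewer (or more) steps than
# the original 1000-step schedule.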
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module, Config):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
image_size,
channels=3,
timesteps=1000,
loss_type="l1",
objective="pred_noise",
beta_schedule="cosine",
):
super().__init__()
self.register(
image_size=image_size,
channels=channels,
timesteps=timesteps,
loss_type=loss_type,
objective=objective,
beta_schedule=beta_schedule,
)
self.channels = channels
self.image_size = image_size
self.objective = objective
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
else:
raise ValueError(f"unknown beta schedule {beta_schedule}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
# helper function to register buffer from float64 to float32
def register_buffer(name, val):
self.register_buffer(name, val.to(torch.float32))
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer("posterior_variance", posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
register_buffer(
"posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, clip_denoised: bool):
model_output = model(x, t)
if self.objective == "pred_noise":
x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
elif self.objective == "pred_x0":
x_start = model_output
else:
raise ValueError(f"unknown objective {self.objective}")
if clip_denoised:
x_start.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
if noise is None:
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return result
@torch.no_grad()
def p_sample_loop(self, model, shape):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(
reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
):
img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
img = unnormalize_to_zero_to_one(img)
return img
@torch.no_grad()
def sample(self, model, batch_size=16):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))
@torch.no_grad()
def interpolate(self, model, x1, x2, t=None, lam=0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)
assert x1.shape == x2.shape
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
img = (1 - lam) * xt1 + lam * xt2
for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
return img
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
@property
def loss_fn(self):
if self.loss_type == "l1":
return F.l1_loss
elif self.loss_type == "l2":
return F.mse_loss
else:
raise ValueError(f"invalid loss type {self.loss_type}")
def p_losses(self, model, x_start, t, noise=None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
x = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = model(x, t)
if self.objective == "pred_noise":
target = noise
elif self.objective == "pred_x0":
target = x_start
else:
raise ValueError(f"unknown objective {self.objective}")
loss = self.loss_fn(model_out, target)
return loss
def forward(self, model, img, *args, **kwargs):
b, _, h, w, device, img_size, = (
*img.shape,
img.device,
self.image_size,
)
assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = normalize_to_neg_one_to_one(img)
return self.p_losses(model, img, t, *args, **kwargs)
......@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .gaussian import GaussianDiffusion
from .gaussian_ddpm import GaussianDDPMScheduler
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from ..configuration_utils import Config
SAMPLING_CONFIG_NAME = "scheduler_config.json"
def linear_beta_schedule(timesteps, beta_start, beta_end):
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
class GaussianDDPMScheduler(nn.Module, Config):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
variance_type="fixed_small",
):
super().__init__()
self.register(
timesteps=timesteps,
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
variance_type=variance_type,
)
self.num_timesteps = int(timesteps)
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
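# "fixed_small" uses the true posterior variance beta-tilde_t (clamped so its
# log stays finite at t=0), "fixed_large" uses beta_t itself: the two fixed
# variance choices discussed in Section 3.2 of the DDPM paper.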
if variance_type == "fixed_small":
log_variance = torch.log(variance.clamp(min=1e-20))
elif variance_type == "fixed_large":
log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
self.register_buffer("betas", betas.to(torch.float32))
self.register_buffer("alphas", alphas.to(torch.float32))
self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
self.register_buffer("log_variance", log_variance.to(torch.float32))
def get_alpha(self, time_step):
return self.alphas[time_step]
def get_beta(self, time_step):
return self.betas[time_step]
def get_alpha_prod(self, time_step):
if time_step < 0:
return torch.tensor(1.0)
return self.alphas_cumprod[time_step]
def sample_variance(self, time_step, shape, device, generator=None):
variance = self.log_variance[time_step]
nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)
noise = self.sample_noise(shape, device=device, generator=generator)
sampled_variance = nonzero_mask * (0.5 * variance).exp()
sampled_variance = sampled_variance * noise
return sampled_variance
def sample_noise(self, shape, device, generator=None):
# always sample on CPU to be deterministic
return torch.randn(shape, generator=generator).to(device)
def __len__(self):
return self.num_timesteps
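# Usage sketch (not part of this commit): the scheduler only stores the noise
# schedule; the sampling arithmetic lives in the caller, as in the tests below.
#   scheduler = GaussianDDPMScheduler(timesteps=1000)
#   assert len(scheduler) == 1000
#   noise = scheduler.sample_noise((1, 3, 32, 32), device="cpu")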
......@@ -16,13 +16,45 @@
import random
import tempfile
import unittest
import os
from distutils.util import strtobool
import torch
from diffusers import GaussianDiffusion, UNetModel
from diffusers import GaussianDDPMScheduler, UNetModel
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError(f"If set, {key} must be yes or no.")
return _value
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def slow(test_case):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
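# Usage: decorate a test with @slow and enable it via the environment, e.g.
#   RUN_SLOW=1 python -m pytest tests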
def floats_tensor(shape, scale=1.0, rng=None, name=None):
......@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):
return (noise, time_step)
def test_from_pretrained_save_pretrained(self):
model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
......@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):
class SamplerTesterMixin(unittest.TestCase):
@property
def dummy_model(self):
return UNetModel.from_pretrained("fusing/ddpm_dummy")
def test_from_pretrained_save_pretrained(self):
sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")
with tempfile.TemporaryDirectory() as tmpdirname:
sampler.save_config(tmpdirname)
new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)
model = self.dummy_model
torch.manual_seed(0)
sampled_out = sampler.sample(model, batch_size=1)
@slow
def test_sample(self):
generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)
# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
# ii) predict noise residual
with torch.no_grad():
noise_residual = model(image, t)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clip_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
# Note: The better test is to simply check with the following lines of code that the image is sensible
# import PIL
# import numpy as np
# image_processed = image.cpu().permute(0, 2, 3, 1)
# image_processed = (image_processed + 1.0) * 127.5
# image_processed = image_processed.numpy().astype(np.uint8)
# image_pil = PIL.Image.fromarray(image_processed[0])
# image_pil.save("test.png")
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3
def test_sample_fast(self):
# 1. Load models
generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise
torch.manual_seed(0)
sampled_out_new = new_sampler.sample(model, batch_size=1)
assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"
def test_from_pretrained_hub(self):
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
model = self.dummy_model
sampled_out = sampler.sample(model, batch_size=1)
assert sampled_out is not None, "Make sure output is not None"
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
# 3. Denoise
for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t
clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
# ii) predict noise residual
with torch.no_grad():
noise_residual = model(image, t)
# iii) compute predicted image from residual
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
pred_mean = torch.clamp(pred_mean, -1, 1)
prev_image = clip_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
image = sampled_prev_image
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3