Commit fe313730 authored by Patrick von Platen's avatar Patrick von Platen

improve

parent 3a5c65d5
@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th
Example:
```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
import torch
# 1. Load model
@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
image = unet(dummy_noise, time_step)
# 3. Load sampler
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
# 4. Sample image from sampler passing the model
image = sampler.sample(model, batch_size=1)
@@ -54,12 +54,12 @@ print(image)
Example:
```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
from modeling_ddpm import DDPM
import tempfile
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
...
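The same save/load round trip also works with a local folder; a minimal sketch, assuming the method pairs registered for each class in `LOADABLE_CLASSES` below (`save_pretrained`/`from_pretrained` for models, `save_config`/`from_config` for the scheduler):

```python
import tempfile

from diffusers import GaussianDDPMScheduler, UNetModel

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

with tempfile.TemporaryDirectory() as tmpdirname:
    # write model weights and scheduler config to disk ...
    unet.save_pretrained(tmpdirname)
    scheduler.save_config(tmpdirname)

    # ... and restore both from the same folder
    unet = UNetModel.from_pretrained(tmpdirname)
    scheduler = GaussianDDPMScheduler.from_config(tmpdirname)
```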
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDDPMScheduler
import torch
import torch.nn.functional as F
import numpy as np
import PIL.Image
import tqdm

#torch_device = "cuda"
#
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
#unet.to(torch_device)
#
#TIME_STEPS = 10
#
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
#diffusion_config = {
#    "beta_start": 0.0001,
#    "beta_end": 0.02,
#    "num_diffusion_timesteps": TIME_STEPS,
#}
#
# 2. Do one denoising step with model
#batch_size, num_channels, height, width = 1, 3, 256, 256
#
#torch.manual_seed(0)
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
#
#
# Helper
#def noise_like(shape, device, repeat=False):
#    def repeat_noise():
#        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
#
#    def noise():
#        return torch.randn(shape, device=device)
#
#    return repeat_noise() if repeat else noise()
#
#
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
#betas = torch.tensor(betas, device=torch_device)
#alphas = 1.0 - betas
#
#alphas_cumprod = torch.cumprod(alphas, axis=0)
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
#coeff = 1 / torch.sqrt(alphas)


def real_fn():
    # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
    # 1: x_T ~ N(0, I)
    x_t = noise_image
    # 2: for t = T, ...., 1 do
    for i in reversed(range(TIME_STEPS)):
        t = torch.tensor([i]).to(torch_device)
        # 3: z ~ N(0, I)
        noise = noise_like(x_t.shape, torch_device)
        # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
        # ------------------------- MODEL ------------------------------------#
        with torch.no_grad():
            pred_noise = unet(x_t, t)  # pred epsilon_theta
        # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
        # pred_x.clamp_(-1.0, 1.0)
        # pred mean
        # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
        # --------------------------------------------------------------------#
        posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)
        # ------------------------- Variance Scheduler -----------------------#
        # pred variance
        posterior_log_variance = posterior_log_variance_clipped[t]
        b, *_, device = *x_t.shape, x_t.device
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
        posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
        # --------------------------------------------------------------------#
        x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
        x_t = x_t_1

    print(x_t.abs().sum())


def post_process_to_image(x_t):
    image = x_t.cpu().permute(0, 2, 3, 1)
    image = (image + 1.0) * 127.5
    image = image.numpy().astype(np.uint8)
    return PIL.Image.fromarray(image[0])


from pytorch_diffusion import Diffusion

#diffusion = Diffusion.from_pretrained("lsun_church")
#samples = diffusion.denoise(1)
#
#image = post_process_to_image(samples)
#image.save("check.png")
#import ipdb; ipdb.set_trace()

device = "cuda"
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
import ipdb; ipdb.set_trace()

model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)

torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)

for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
    # define coefficients for time step t
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

    # predict noise residual
    with torch.no_grad():
        noise_residual = model(next_image, t)

    # compute prev image from noise
    pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    image = clip_coeff * pred_mean + image_coeff * next_image

    # sample variance
    variance = scheduler.sample_variance(t, image.shape, device=device)

    # sample previous image
    sampled_image = image + variance

    next_image = sampled_image

image = post_process_to_image(next_image)
image.save("example_new.png")
#!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
-from modeling_ddpm import DDPM
import tempfile
+from diffusers import GaussianDDPMScheduler, UNetModel
+from modeling_ddpm import DDPM
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")
# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
...
@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline):
    def __init__(self, unet, gaussian_sampler):
        super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
...
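A hypothetical sketch of how the composed pipeline could expose end-to-end sampling. Only `__init__` appears in this diff, so the `__call__` below is an illustration (assumption: `DiffusionPipeline.__init__` exposes the passed modules as attributes), not part of the commit:

```python
from diffusers import DiffusionPipeline


class DDPM(DiffusionPipeline):
    def __init__(self, unet, gaussian_sampler):
        super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)

    def __call__(self, batch_size=1):
        # delegate to the registered modules, mirroring the README example
        # `image = sampler.sample(model, batch_size=1)`
        return self.gaussian_sampler.sample(self.unet, batch_size=batch_size)
```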
#!/usr/bin/env python3
import torch
-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
-diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2
+diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2
training_images = torch.randn(8, 3, 128, 128)  # your images should be in [0, 1]; they are normalized to the range -1 to +1 internally
loss = diffusion(training_images)
...
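`loss = diffusion(training_images)` draws a uniform timestep per image, noises the batch with the closed-form forward process, and evaluates the simple DDPM objective (L1 or L2 depending on `loss_type`), as implemented by `q_sample` and `p_losses` in `gaussian.py` below:

```latex
L_{\text{simple}} = \mathbb{E}_{x_0,\, t,\, \epsilon}\!\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,\; t\right)\right\rVert\right]
```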
@@ -4,8 +4,7 @@
__version__ = "0.0.1"

+from .modeling_utils import PreTrainedModel
from .models.unet import UNetModel
-from .samplers.gaussian import GaussianDiffusion
from .pipeline_utils import DiffusionPipeline
-from .modeling_utils import PreTrainedModel
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
@@ -17,10 +17,10 @@
import copy
+import inspect
import json
import os
import re
-import inspect
from typing import Any, Dict, Tuple, Union
from requests import HTTPError
@@ -186,6 +186,11 @@ class Config:
        expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
        expected_keys.remove("self")
+        for key in expected_keys:
+            if key in kwargs:
+                # overwrite key
+                config_dict[key] = kwargs.pop(key)
        passed_keys = set(config_dict.keys())
        unused_kwargs = kwargs
@@ -194,17 +199,16 @@ class Config:
        if len(expected_keys - passed_keys) > 0:
            logger.warn(
-                f"{expected_keys - passed_keys} was not found in config. "
-                f"Values will be initialized to default values."
+                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )
        return config_dict, unused_kwargs
    @classmethod
-    def from_config(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
-    ):
-        config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+        config_dict, unused_kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+        )
        model = cls(**config_dict)
...
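The new loop in `get_config_dict` makes keyword arguments passed to `from_config` override values stored in the saved config, which is what lets the scripts and tests in this commit shorten the diffusion chain at load time:

```python
from diffusers import GaussianDDPMScheduler

# the config saved on the Hub specifies the full chain (e.g. timesteps=1000);
# the explicit kwarg takes precedence over the stored value
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
assert len(scheduler) == 10
```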
@@ -24,6 +24,7 @@ from requests import HTTPError
# CHANGE to diffusers.utils
from transformers.utils import (
+    CONFIG_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    EntryNotFoundError,
    RepositoryNotFoundError,
@@ -33,7 +34,6 @@ from transformers.utils import (
    is_offline_mode,
    is_remote_url,
    logging,
-    CONFIG_NAME,
)
...
@@ -14,15 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
import os
from typing import Optional, Union
-import importlib
-from .configuration_utils import Config
# CHANGE to diffusers.utils
from transformers.utils import logging
+from .configuration_utils import Config
INDEX_FILE = "diffusion_model.pt"
@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
    "diffusers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
-        "GaussianDiffusion": ["save_config", "from_config"],
+        "GaussianDDPMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
...
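A hedged sketch of the dispatch that `LOADABLE_CLASSES` enables: for each registered module, find the first base class it matches and call the registered save method. The helper name and exact control flow below are assumptions for illustration; only the mapping itself appears in this diff:

```python
import importlib


def save_module(module, save_directory):
    # hypothetical helper: LOADABLE_CLASSES (defined above) maps a library name
    # to {class name: [save method, load method]}; pick the matching pair and
    # invoke it, e.g. save_pretrained for models or save_config for schedulers
    for library_name, class_methods in LOADABLE_CLASSES.items():
        library = importlib.import_module(library_name)
        for class_name, (save_method, _load_method) in class_methods.items():
            base_class = getattr(library, class_name, None)
            if base_class is not None and isinstance(module, base_class):
                getattr(module, save_method)(save_directory)
                return
```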
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from inspect import isfunction
from tqdm import tqdm

from ..configuration_utils import Config

SAMPLING_CONFIG_NAME = "sampler_config.json"


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


# gaussian diffusion trainer class
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module, Config):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        image_size,
        channels=3,
        timesteps=1000,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="cosine",
    ):
        super().__init__()
        self.register(
            image_size=image_size,
            channels=channels,
            timesteps=timesteps,
            loss_type=loss_type,
            objective=objective,
            beta_schedule=beta_schedule,
        )
        self.channels = channels
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer(
            "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised: bool):
        model_output = model(x, t)

        if self.objective == "pred_noise":
            x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
        elif self.objective == "pred_x0":
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
        if noise is None:
            noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return result

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(
            reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
        ):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, model, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")

    def p_losses(self, model, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = model(x, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)
        return loss

    def forward(self, model, img, *args, **kwargs):
        b, _, h, w, device, img_size, = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(model, img, t, *args, **kwargs)
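`q_sample` relies on the closed form of the forward process, which is why training can draw an arbitrary `t` without simulating the full chain:

```latex
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\; \sqrt{\bar{\alpha}_t}\, x_0,\; (1 - \bar{\alpha}_t)\, I\right)
\;\;\Longrightarrow\;\;
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
```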
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .gaussian import GaussianDiffusion
+from .gaussian_ddpm import GaussianDDPMScheduler
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from ..configuration_utils import Config

SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


class GaussianDDPMScheduler(nn.Module, Config):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
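The two `variance_type` options correspond to the two choices of the sampling variance discussed in the DDPM paper (Section 3.2); `fixed_large` patches its first entry with the clipped posterior variance so that the logarithm stays finite at the start of the chain:

```latex
\text{fixed\_small:}\;\; \sigma_t^2 = \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t
\qquad\qquad
\text{fixed\_large:}\;\; \sigma_t^2 = \beta_t
```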
@@ -16,13 +16,45 @@
import random
import tempfile
import unittest
+import os
+from distutils.util import strtobool
import torch
-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel
global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        _value = default
    else:
        # KEY is set, convert it to True or False.
        try:
            _value = strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
    """
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):
        return (noise, time_step)

    def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):
class SamplerTesterMixin(unittest.TestCase):
    @slow
    def test_sample(self):
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        # 1. Load models
        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        # Note: The better test is to simply check with the following lines of code that the image is sensible
        # import PIL
        # import numpy as np
        # image_processed = image.cpu().permute(0, 2, 3, 1)
        # image_processed = (image_processed + 1.0) * 127.5
        # image_processed = image_processed.numpy().astype(np.uint8)
        # image_pil = PIL.Image.fromarray(image_processed[0])
        # image_pil.save("test.png")

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3

    def test_sample_fast(self):
        # 1. Load models
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        torch.manual_seed(0)
        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3