"tools/chunk_graph.py" did not exist on "60bc0b7692c6733dd930cd4a502f4171720da38d"
Commit b02d0d6b authored by Patrick von Platen's avatar Patrick von Platen
Browse files

merge

parents 49257b4a 02cdd683
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Denoising Diffusion Implicit Models (DDIM)
## Overview
DDIM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon*.
The abstract from the paper is the following:
*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
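In the notation of the paper, the sampler implemented further down this page applies the DDIM update rule: given the cumulative products \\(\bar\alpha_t\\) and the predicted noise \\(\epsilon_\theta(x_t, t)\\),

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t \epsilon_t, \qquad \hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}, \qquad \sigma_t = \eta\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\sqrt{1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}}.$$

Setting \\(\eta = 0\\) gives a deterministic sampler, while \\(\eta = 1\\) essentially recovers DDPM sampling. In the reference pipeline below, `coeff_1` plays the role of \\(\sigma_t\\) and `coeff_2` the role of \\(\sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\\).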
Tips:
- ...
- ...
This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
for model_id in model_ids:
    path = os.path.join("/home/patrick/images/hf", model_id)
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    ddim = DDIM.from_pretrained("fusing/" + model_id)
    image = ddim(batch_size=4)

    image_processed = image.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.numpy().astype(np.uint8)

    for i in range(image_processed.shape[0]):
        image_pil = PIL.Image.fromarray(image_processed[i])
        image_pil.save(os.path.join(path, f"image_{i}.png"))
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch
class DDIM(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
        # eta corresponds to η in paper and should be between [0, 1]
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        num_trained_timesteps = self.noise_scheduler.num_timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        self.unet.to(torch_device)

        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )

        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # get actual t and t-1
            train_step = inference_step_times[t]
            prev_train_step = inference_step_times[t - 1] if t > 0 else -1

            # compute alphas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
            beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
            beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()

            # compute relevant coefficients
            coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()

            # model forward
            with torch.no_grad():
                noise_residual = self.unet(image, train_step)

            # predict mean of prev image
            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
            pred_mean = torch.clamp(pred_mean, -1, 1)
            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual

            # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
            if eta > 0.0:
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                image = pred_mean + coeff_1 * noise
            else:
                image = pred_mean

        return image
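# Example usage (a sketch, assuming a compatible checkpoint such as "fusing/ddim-celeba-hq"
# is available on the Hub): eta controls the stochasticity of the sampler, with eta=0.0 giving
# deterministic DDIM sampling and eta=1.0 essentially recovering DDPM-style sampling.
#
#   ddim = DDIM.from_pretrained("fusing/ddim-celeba-hq")
#   deterministic = ddim(batch_size=4, eta=0.0, num_inference_steps=50)
#   stochastic = ddim(batch_size=4, eta=1.0, num_inference_steps=50)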
#!/usr/bin/env python3
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDDPMScheduler(
    model,
    image_size=128,
    timesteps=1000,   # number of diffusion steps
    loss_type="l1",   # L1 or L2
)
training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized to the range [-1, 1]
loss = diffusion(training_images)
loss.backward()
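# In practice the loss/backward step above is repeated many times. A minimal training-loop
# sketch under the same assumptions (the UNet's trainable weights live in `model.parameters()`
# and `diffusion(batch)` returns a scalar loss, as shown above); a real run would iterate over
# a DataLoader of images normalized to [-1, 1]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for step in range(1000):
    optimizer.zero_grad()
    batch = torch.randn(8, 3, 128, 128)  # stand-in for a real batch of training images
    loss = diffusion(batch)
    loss.backward()
    optimizer.step()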
# after a lot of training
sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape # (4, 3, 128, 128)
#!/usr/bin/env python3
# !pip install diffusers
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"
# load model and scheduler
ddpm = DDIM.from_pretrained(model_id)
# run pipeline in inference (sample random noise and denoise)
image = ddpm()
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
# save image
image_pil.save("/home/patrick/images/show.png")
@@ -21,8 +21,6 @@ import torch
 class DDPM(DiffusionPipeline):
-    modeling_file = "modeling_ddpm.py"
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
...
# References
[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
import argparse

import torch
from torch import nn

from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from modeling_glide import GLIDE
from transformers import CLIPTextConfig, GPT2Tokenizer

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}

### Convert the text encoder

config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
tokenizer.save_pretrained("./glide-base")

hf_encoder = model.text_model

@@ -30,15 +35,8 @@ hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for layer_idx in range(config.num_hidden_layers):
    hf_layer = hf_encoder.encoder.layers[layer_idx]
    hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
    hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]

    hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
    hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]

@@ -53,8 +51,56 @@ for layer_idx in range(config.num_hidden_layers):
    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

### Convert the Text-to-Image UNet

text2im_model = GLIDETextToImageUNetModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    transformer_dim=512,
)

text2im_model.load_state_dict(state_dict, strict=False)

text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

### Convert the Super-Resolution UNet

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")

superres_model = GLIDESuperResUNetModel(
    in_channels=6,
    model_channels=192,
    out_channels=6,
    num_res_blocks=2,
    attention_resolutions=(8, 16, 32),
    dropout=0.1,
    channel_mult=(1, 1, 2, 2, 4, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)

superres_model.load_state_dict(ups_state_dict, strict=False)

upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")

glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")
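# Optional sanity check (a sketch, not part of the original script): because the UNet weights
# are loaded with strict=False above, it can be worth inspecting which checkpoint keys were
# not matched. PyTorch's load_state_dict returns the missing and unexpected keys.
missing, unexpected = text2im_model.load_state_dict(state_dict, strict=False)
print(f"text2im UNet  - missing: {len(missing)}, unexpected: {len(unexpected)}")
missing, unexpected = superres_model.load_state_dict(ups_state_dict, strict=False)
print(f"superres UNet - missing: {len(missing)}, unexpected: {len(unexpected)}")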
@@ -14,46 +14,215 @@
# limitations under the License.

import numpy as np
import torch
import tqdm

from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
            upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :return: a tuple of four tensors:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        B, C = x.shape[:2]
        assert t.shape == (B,)
        if transformer_out is None:
            # super-res model
            model_output = model(x, t, low_res)
        else:
            # text2image model
            model_output = model(x, t, transformer_out)

        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)

        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None):
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        # Create a classifier-free guidance sampling function
        guidance_scale = 3.0

        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = self.text_noise_scheduler.sample_noise(
            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
        # an empty prompt is added as the unconditional input for classifier-free guidance
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
        num_timesteps = len(self.text_noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
            )

            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        # 4. Run the upscaling step
        batch_size = 1
        image = image[:1]
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1

        eta = 0.0

        # Tune this parameter to control the sharpness of 256x256 images.
        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
        upsample_temp = 0.997

        image = self.upscale_noise_scheduler.sample_noise(
            (batch_size, 3, 256, 256), device=torch_device, generator=generator
        ) * upsample_temp

        num_timesteps = len(self.upscale_noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * self.upscale_noise_scheduler.get_beta(t)
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
            model_output = self.upscale_unet(image, time_input, low_res)
            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = self.upscale_noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        image = image[0].permute(1, 2, 0)

        return image
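# Note: the text_model_fn defined inside __call__ above implements classifier-free guidance.
# The batch holds a conditional and an unconditional ("" prompt) copy of the same latent, and
# the two noise predictions are recombined before every denoising step. A standalone sketch of
# that recombination (guidance_scale > 1.0 pushes samples towards the prompt):
#
#   def classifier_free_guidance(eps_cond, eps_uncond, guidance_scale=3.0):
#       return eps_uncond + guidance_scale * (eps_cond - eps_uncond)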
import torch

import PIL.Image
from diffusers import DiffusionPipeline

generator = torch.Generator()
generator = generator.manual_seed(0)

model_id = "fusing/glide-base"

# load model and scheduler
pipeline = DiffusionPipeline.from_pretrained(model_id)

# run inference (text-conditioned denoising + upscaling)
img = pipeline("a clip art of a hugging face", generator)

# process image to PIL
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)

# save image
image_pil.save("test.png")
@@ -5,7 +5,12 @@
__version__ = "0.0.1"

from .modeling_utils import ModelMixin
from .models.clip_text_transformer import CLIPTextModel
from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.glide_ddim import GlideDDIMScheduler
@@ -215,6 +215,7 @@ class ConfigMixin:
            init_dict[key] = config_dict.pop(key)

        unused_kwargs = config_dict.update(kwargs)

        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.warn(
...
@@ -16,5 +16,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .clip_text_transformer import CLIPTextModel
from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_ldm import UNetLDMModel
from .vqvae import VQModel
@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
        return a.reshape(bs, -1, length)


class GLIDEUNetModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.

@@ -419,11 +419,11 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=3,
        model_channels=192,
        out_channels=6,
        num_res_blocks=3,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
@@ -435,28 +435,9 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_dim=None,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
@@ -470,7 +451,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
@@ -508,7 +489,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                        use_checkpoint=use_checkpoint,
                        num_heads=num_heads,
                        num_head_channels=num_head_channels,
                        encoder_channels=transformer_dim,
                    )
                )
            self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -551,7 +532,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                encoder_channels=transformer_dim,
            ),
            ResBlock(
                ch,
@@ -587,7 +568,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
                        use_checkpoint=use_checkpoint,
                        num_heads=num_heads_upsample,
                        num_head_channels=num_head_channels,
                        encoder_channels=transformer_dim,
                    )
                )
            if level and i == num_res_blocks:
@@ -633,7 +614,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

@@ -642,18 +623,184 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
class GLIDETextToImageUNetModel(GLIDEUNetModel):
    """
    A UNetModel that conditions the denoising process on text.

    Expects an extra kwarg `transformer_out` with the hidden states of the text encoder.
    """

    def __init__(
        self,
        in_channels=3,
        model_channels=192,
        out_channels=6,
        num_res_blocks=3,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_dim=512,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )
        self.register(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )
        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)

    def forward(self, x, timesteps, transformer_out=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        # project the last token
        transformer_proj = self.transformer_proj(transformer_out[:, -1])
        transformer_out = transformer_out.permute(0, 2, 1)  # NLC -> NCL

        emb = emb + transformer_proj.to(emb)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, transformer_out)
            hs.append(h)
        h = self.middle_block(h, emb, transformer_out)
        for module in self.output_blocks:
            other = hs.pop()
            h = torch.cat([h, other], dim=1)
            h = module(h, emb, transformer_out)
        return self.out(h)
class GLIDESuperResUNetModel(GLIDEUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(
        self,
        in_channels=3,
        model_channels=192,
        out_channels=6,
        num_res_blocks=3,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
        )
        self.register(
            in_channels=in_channels,
            model_channels=model_channels,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            dropout=dropout,
            channel_mult=channel_mult,
            conv_resample=conv_resample,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_fp16=use_fp16,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
        )

    def forward(self, x, timesteps, low_res=None):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = torch.cat([x, upsampled], dim=1)

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        h = x
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
@@ -661,5 +808,5 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        return self.out(h)
@@ -17,6 +17,7 @@
import importlib
import os
from typing import Optional, Union

from huggingface_hub import snapshot_download

from .utils import logging, DIFFUSERS_CACHE
@@ -34,10 +35,13 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
        "GaussianDDPMScheduler": ["save_config", "from_config"],
        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
        "GlideDDIMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
    },
}
@@ -50,6 +54,10 @@ class DiffusionPipeline(ConfigMixin):
        for name, module in kwargs.items():
            # retrieve library
            library = module.__module__.split(".")[0]

            # if library is not in LOADABLE_CLASSES, then it is a custom module
            if library not in LOADABLE_CLASSES:
                library = module.__module__.split(".")[-1]

            # retrieve class_name
            class_name = module.__class__.__name__
@@ -61,7 +69,7 @@ class DiffusionPipeline(ConfigMixin):
            # set models
            setattr(self, name, module)

        register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
        self.register(**register_dict)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
@@ -123,25 +131,40 @@ class DiffusionPipeline(ConfigMixin):
        module = config_dict["_module"]
        class_name_ = config_dict["_class_name"]

        module_candidate = config_dict["_module"]
        module_candidate_name = module_candidate.replace(".py", "")

        # if we load from explicit class, let's use it
        if cls != DiffusionPipeline:
            pipeline_class = cls
        else:
            # else we need to load the correct module from the Hub
            class_name_ = config_dict["_class_name"]
            module = module_candidate
            pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)

        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}

        # get all importable classes to get the load method name for custom models/components
        # here we enforce that custom models/components should always subclass from base classes in transformers and diffusers
        all_importable_classes = {}
        for library in LOADABLE_CLASSES:
            all_importable_classes.update(LOADABLE_CLASSES[library])

        for name, (library_name, class_name) in init_dict.items():
            # if the model is not in diffusers or transformers, we need to load it from the Hub
            # assumes that it's a subclass of ModelMixin
            if library_name == module_candidate_name:
                class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)

                # since it's not from a library, we need to check class candidates for all importable classes
                importable_classes = all_importable_classes
                class_candidates = {c: class_obj for c in all_importable_classes}
            else:
                library = importlib.import_module(library_name)
                class_obj = getattr(library, class_name)
                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

            load_method_name = None
...
@@ -16,4 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler