Commit 0e13d329 authored by anton-l's avatar anton-l
Browse files

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	tests/test_modeling_utils.py
parents 3f9e3d8a e13ee8b5
......@@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
#### **Example 1024x1024 image generation with SDE VE**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=2000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Example 32x32 image generation with SDE VP**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=1000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
......@@ -249,24 +299,24 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
#### **Text to speech with GradTTS and BDDM**
#### **Text to speech with GradTTS and BDDMPipeline**
```python
import torch
from diffusers import BDDM, GradTTS
from diffusers import BDDMPipeline, GradTTSPipeline
torch_device = "cuda"
# load grad tts and bddm pipelines
grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech")
grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")
text = "Hello world, I missed you so much."
# generate mel spectograms using text
mel_spec = grad_tts(text, torch_device=torch_device)
# generate the speech by passing mel spectograms to BDDM pipeline
# generate the speech by passing mel spectograms to BDDMPipeline pipeline
generator = torch.manual_seed(42)
audio = bddm(mel_spec, generator, torch_device=torch_device)
......@@ -288,3 +338,4 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
- [ ] Add more vision models
- [ ] Add more speech models
- [ ] Add RL model
- [ ] Add FID and KID metrics
#!/usr/bin/env python3
import numpy as np
import PIL
import torch
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NewReverseDiffusionPredictor:
def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
super().__init__()
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.probability_flow = probability_flow
self.score_fn = score_fn
def discretize(self, x, t):
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.discrete_sigmas[timestep - 1].to(t.device))
f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
result = self.score_fn(x, labels)
rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
def update_fn(self, x, t):
f, G = self.discretize(x, t)
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
class NewLangevinCorrector:
def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
super().__init__()
self.score_fn = score_fn
self.snr = snr
self.n_steps = n_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
def update_fn(self, x, t):
score_fn = self.score_fn
n_steps = self.n_steps
target_snr = self.snr
# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
# timestep = (t * (sde.N - 1) / sde.T).long()
# alpha = sde.alphas.to(t.device)[timestep]
# else:
alpha = torch.ones_like(t)
for i in range(n_steps):
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
grad = score_fn(x, labels)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
x_mean = x + step_size[:, None, None, None] * grad
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
return x, x_mean
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
N = 2
sigma_min = 0.01
sigma_max = 1348
sampling_eps = 1e-5
batch_size = 1
centered = False
from diffusers import NCSNpp
model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)
img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15
n_steps = 1
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
with torch.no_grad():
# Initial sample
x = torch.randn(*shape) * sigma_max
x = x.to(device)
timesteps = torch.linspace(1, sampling_eps, N, device=device)
for i in range(N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x = x_mean
if centered:
x = (x + 1.) / 2.
# save_image(x)
# for 5 cifar10
x_sum = 106071.9922
x_mean = 34.52864456176758
# for 1000 cifar10
x_sum = 461.9700
x_mean = 0.1504
# for 2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
......@@ -7,23 +7,29 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, PNDM
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
GradTTSScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
ScoreSdeVpScheduler,
)
if is_transformers_available():
from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .pipelines import Glide, LatentDiffusion
from .pipelines import GlidePipeline, LatentDiffusionPipeline
else:
from .utils.dummy_transformers_objects import *
if is_transformers_available() and is_inflect_available() and is_unidecode_available():
from .pipelines import GradTTS
from .pipelines import GradTTSPipeline
else:
from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
......@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
from .unet_sde_score_estimation import NCSNpp
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
from torch import nn
def get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
emb = torch.exp(emb * emb_coeff)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
import torch
import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def nonlinearity(x, swish=1.0):
# swish
if swish == 1.0:
return F.silu(x)
else:
return x * F.sigmoid(x * float(swish))
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
return self.down(x)
class UNetUpsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class GlideUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class LDMUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class GradTTSUpsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
# class ResnetBlock(nn.Module):
# def __init__(
# self,
# *,
# in_channels,
# out_channels=None,
# conv_shortcut=False,
# dropout,
# temb_channels=512,
# use_scale_shift_norm=False,
# ):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
# self.use_scale_shift_norm = use_scale_shift_norm
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels
# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
# # TODO: check if this broadcasting works correctly for 1D and 3D
# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(temb, 2, dim=1)
# h = self.norm2(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + temb
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
# return x + h
......@@ -30,27 +30,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
from .embeddings import get_timestep_embedding
def nonlinearity(x):
......
......@@ -7,6 +7,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
def convert_module_to_f16(l):
......@@ -86,27 +87,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
......@@ -627,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
"""
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x.type(self.dtype)
for module in self.input_blocks:
......@@ -714,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
def forward(self, x, timesteps, transformer_out=None):
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
# project the last token
transformer_proj = self.transformer_proj(transformer_out[:, -1])
......@@ -806,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
x = torch.cat([x, upsampled], dim=1)
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x
for module in self.input_blocks:
......
import math
import torch
......@@ -11,6 +9,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
class Mish(torch.nn.Module):
......@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
return output
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super(SinusoidalPosEmb, self).__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__()
......@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
......@@ -198,7 +181,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
t = self.mlp(t)
if self.n_spks < 2:
......
......@@ -16,6 +16,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
def exists(val):
......@@ -316,36 +317,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
## go
class AttentionPool2d(nn.Module):
"""
......@@ -1026,7 +997,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
hs = []
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
t_emb = timestep_embedding(timesteps, self.model_channels)
t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
......@@ -1240,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
results = []
h = x.type(self.dtype)
......
......@@ -13,7 +13,6 @@ except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
......@@ -107,14 +106,21 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__(
self,
horizon,
training_horizon,
transition_dim,
cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
self.transition_dim = transition_dim
self.cond_dim = cond_dim
self.predict_epsilon = predict_epsilon
self.clip_denoised = clip_denoised
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
......@@ -138,19 +144,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.downs.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon // 2
training_horizon = training_horizon // 2
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
......@@ -158,15 +164,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon * 2
training_horizon = training_horizon * 2
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5),
......@@ -232,7 +238,6 @@ class TemporalValue(nn.Module):
print(in_out)
for dim_in, dim_out in in_out:
self.blocks.append(
nn.ModuleList(
[
......
This diff is collapsed.
......@@ -21,7 +21,6 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging
......@@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
register_dict = {"_module": self.__module__.split(".")[-1]}
self.register_to_config(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module")
model_index_dict.pop("_module", None)
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
......@@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder)
# 2. Get class name and module candidates to load custom models
module_candidate_name = config_dict["_module"]
module_candidate = module_candidate_name + ".py"
# 3. Load the pipeline class, if using custom module then load it from the hub
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if cls != DiffusionPipeline:
pipeline_class = cls
......@@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# (TODO - we should allow to load custom pipelines
# else we need to load the correct module from the Hub
# module = module_candidate
# pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
......@@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin):
# import it here to avoid circular import
from diffusers import pipelines
# 4. Load each module in the pipeline
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name)
# if the model is in a pipeline module, then we load it from the pipeline
......@@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin):
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
elif library_name == module_candidate_name:
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
......
......@@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj)
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
if is_transformers_available():
from .pipeline_glide import Glide
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusionPipeline
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .pipeline_grad_tts import GradTTS
from .pipeline_grad_tts import GradTTSPipeline
......@@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin):
return self.final_conv(x)
class BDDM(DiffusionPipeline):
class BDDMPipeline(DiffusionPipeline):
def __init__(self, diffwave, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline
class DDIM(DiffusionPipeline):
class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline
class DDPM(DiffusionPipeline):
class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
......
......@@ -695,7 +695,23 @@ class CLIPTextModel(CLIPPreTrainedModel):
#####################
class Glide(DiffusionPipeline):
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
class GlidePipeline(DiffusionPipeline):
def __init__(
self,
text_unet: GlideTextToImageUNetModel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment