Commit 0e13d329 authored by anton-l's avatar anton-l
Browse files

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	tests/test_modeling_utils.py
parents 3f9e3d8a e13ee8b5
...@@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0]) ...@@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png") image_pil.save("test.png")
``` ```
#### **Example 1024x1024 image generation with SDE VE**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=2000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Example 32x32 image generation with SDE VP**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=1000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Text to Image generation with Latent Diffusion** #### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._ _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
...@@ -249,24 +299,24 @@ image_pil = PIL.Image.fromarray(image_processed[0]) ...@@ -249,24 +299,24 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png") image_pil.save("test.png")
``` ```
#### **Text to speech with GradTTS and BDDM** #### **Text to speech with GradTTS and BDDMPipeline**
```python ```python
import torch import torch
from diffusers import BDDM, GradTTS from diffusers import BDDMPipeline, GradTTSPipeline
torch_device = "cuda" torch_device = "cuda"
# load grad tts and bddm pipelines # load grad tts and bddm pipelines
grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts") grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech") bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")
text = "Hello world, I missed you so much." text = "Hello world, I missed you so much."
# generate mel spectograms using text # generate mel spectograms using text
mel_spec = grad_tts(text, torch_device=torch_device) mel_spec = grad_tts(text, torch_device=torch_device)
# generate the speech by passing mel spectograms to BDDM pipeline # generate the speech by passing mel spectograms to BDDMPipeline pipeline
generator = torch.manual_seed(42) generator = torch.manual_seed(42)
audio = bddm(mel_spec, generator, torch_device=torch_device) audio = bddm(mel_spec, generator, torch_device=torch_device)
...@@ -288,3 +338,4 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy()) ...@@ -288,3 +338,4 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
- [ ] Add more vision models - [ ] Add more vision models
- [ ] Add more speech models - [ ] Add more speech models
- [ ] Add RL model - [ ] Add RL model
- [ ] Add FID and KID metrics
#!/usr/bin/env python3
import numpy as np
import PIL
import torch
#from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NewReverseDiffusionPredictor:
def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
super().__init__()
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.probability_flow = probability_flow
self.score_fn = score_fn
def discretize(self, x, t):
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.discrete_sigmas[timestep - 1].to(t.device))
f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
result = self.score_fn(x, labels)
rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
def update_fn(self, x, t):
f, G = self.discretize(x, t)
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
class NewLangevinCorrector:
def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
super().__init__()
self.score_fn = score_fn
self.snr = snr
self.n_steps = n_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
def update_fn(self, x, t):
score_fn = self.score_fn
n_steps = self.n_steps
target_snr = self.snr
# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
# timestep = (t * (sde.N - 1) / sde.T).long()
# alpha = sde.alphas.to(t.device)[timestep]
# else:
alpha = torch.ones_like(t)
for i in range(n_steps):
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
grad = score_fn(x, labels)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
x_mean = x + step_size[:, None, None, None] * grad
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
return x, x_mean
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
N = 2
sigma_min = 0.01
sigma_max = 1348
sampling_eps = 1e-5
batch_size = 1
centered = False
from diffusers import NCSNpp
model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)
img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15
n_steps = 1
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
with torch.no_grad():
# Initial sample
x = torch.randn(*shape) * sigma_max
x = x.to(device)
timesteps = torch.linspace(1, sampling_eps, N, device=device)
for i in range(N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x = x_mean
if centered:
x = (x + 1.) / 2.
# save_image(x)
# for 5 cifar10
x_sum = 106071.9922
x_mean = 34.52864456176758
# for 1000 cifar10
x_sum = 461.9700
x_mean = 0.1504
# for 2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
...@@ -7,23 +7,29 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode ...@@ -7,23 +7,29 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4" __version__ = "0.0.4"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.unet import UNetModel from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, PNDM from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers import (
DDIMScheduler,
DDPMScheduler,
GradTTSScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
ScoreSdeVpScheduler,
)
if is_transformers_available(): if is_transformers_available():
from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_grad_tts import UNetGradTTSModel
from .pipelines import Glide, LatentDiffusion from .pipelines import GlidePipeline, LatentDiffusionPipeline
else: else:
from .utils.dummy_transformers_objects import * from .utils.dummy_transformers_objects import *
if is_transformers_available() and is_inflect_available() and is_unidecode_available(): if is_transformers_available() and is_inflect_available() and is_unidecode_available():
from .pipelines import GradTTS from .pipelines import GradTTSPipeline
else: else:
from .utils.dummy_transformers_and_inflect_and_unidecode_objects import * from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
...@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide ...@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
from .unet_grad_tts import UNetGradTTSModel from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet from .unet_rl import TemporalUNet
from .unet_sde_score_estimation import NCSNpp
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
from torch import nn
def get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
emb = torch.exp(emb * emb_coeff)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
import torch
import torch.nn as nn
import torch.nn.functional as F
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.ConvTranspose1d(*args, **kwargs)
elif dims == 2:
return nn.ConvTranspose2d(*args, **kwargs)
elif dims == 3:
return nn.ConvTranspose3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def nonlinearity(x, swish=1.0):
# swish
if swish == 1.0:
return F.silu(x)
else:
return x * F.sigmoid(x * float(swish))
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.padding = padding
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0 and self.dims == 2:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
return self.down(x)
class UNetUpsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class GlideUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class LDMUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class GradTTSUpsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
# class ResnetBlock(nn.Module):
# def __init__(
# self,
# *,
# in_channels,
# out_channels=None,
# conv_shortcut=False,
# dropout,
# temb_channels=512,
# use_scale_shift_norm=False,
# ):
# super().__init__()
# self.in_channels = in_channels
# out_channels = in_channels if out_channels is None else out_channels
# self.out_channels = out_channels
# self.use_conv_shortcut = conv_shortcut
# self.use_scale_shift_norm = use_scale_shift_norm
# self.norm1 = Normalize(in_channels)
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels
# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles)
# self.norm2 = Normalize(out_channels)
# self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
# else:
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# def forward(self, x, temb):
# h = x
# h = self.norm1(h)
# h = nonlinearity(h)
# h = self.conv1(h)
# # TODO: check if this broadcasting works correctly for 1D and 3D
# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(temb, 2, dim=1)
# h = self.norm2(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + temb
# h = self.norm2(h)
# h = nonlinearity(h)
# h = self.dropout(h)
# h = self.conv2(h)
# if self.in_channels != self.out_channels:
# if self.use_conv_shortcut:
# x = self.conv_shortcut(x)
# else:
# x = self.nin_shortcut(x)
# return x + h
...@@ -30,27 +30,7 @@ from tqdm import tqdm ...@@ -30,27 +30,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x): def nonlinearity(x):
......
...@@ -7,6 +7,7 @@ import torch.nn.functional as F ...@@ -7,6 +7,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
def convert_module_to_f16(l): def convert_module_to_f16(l):
...@@ -86,27 +87,6 @@ def normalization(channels, swish=0.0): ...@@ -86,27 +87,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def zero_module(module): def zero_module(module):
""" """
Zero out the parameters of a module and return it. Zero out the parameters of a module and return it.
...@@ -627,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): ...@@ -627,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
""" """
hs = [] hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x.type(self.dtype) h = x.type(self.dtype)
for module in self.input_blocks: for module in self.input_blocks:
...@@ -714,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel): ...@@ -714,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
def forward(self, x, timesteps, transformer_out=None): def forward(self, x, timesteps, transformer_out=None):
hs = [] hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
# project the last token # project the last token
transformer_proj = self.transformer_proj(transformer_out[:, -1]) transformer_proj = self.transformer_proj(transformer_out[:, -1])
...@@ -806,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel): ...@@ -806,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
x = torch.cat([x, upsampled], dim=1) x = torch.cat([x, upsampled], dim=1)
hs = [] hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
h = x h = x
for module in self.input_blocks: for module in self.input_blocks:
......
import math
import torch import torch
...@@ -11,6 +9,7 @@ except: ...@@ -11,6 +9,7 @@ except:
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
...@@ -107,21 +106,6 @@ class Residual(torch.nn.Module): ...@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
return output return output
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super(SinusoidalPosEmb, self).__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class UNetGradTTSModel(ModelMixin, ConfigMixin): class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__() super(UNetGradTTSModel, self).__init__()
...@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
) )
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
...@@ -198,7 +181,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ...@@ -198,7 +181,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if not isinstance(spk, type(None)): if not isinstance(spk, type(None)):
s = self.spk_mlp(spk) s = self.spk_mlp(spk)
t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
t = self.mlp(t) t = self.mlp(t)
if self.n_spks < 2: if self.n_spks < 2:
......
...@@ -16,6 +16,7 @@ except: ...@@ -16,6 +16,7 @@ except:
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
def exists(val): def exists(val):
...@@ -316,36 +317,6 @@ def normalization(channels, swish=0.0): ...@@ -316,36 +317,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
## go ## go
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):
""" """
...@@ -1026,7 +997,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): ...@@ -1026,7 +997,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
hs = [] hs = []
if not torch.is_tensor(timesteps): if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
t_emb = timestep_embedding(timesteps, self.model_channels) t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
...@@ -1240,7 +1211,9 @@ class EncoderUNetModel(nn.Module): ...@@ -1240,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
:param timesteps: a 1-D batch of timesteps. :param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs. :return: an [N x K] Tensor of outputs.
""" """
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
results = [] results = []
h = x.type(self.dtype) h = x.type(self.dtype)
......
...@@ -13,7 +13,6 @@ except: ...@@ -13,7 +13,6 @@ except:
print("Einops is not installed") print("Einops is not installed")
pass pass
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
...@@ -107,14 +106,21 @@ class ResidualTemporalBlock(nn.Module): ...@@ -107,14 +106,21 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__( def __init__(
self, self,
horizon, training_horizon,
transition_dim, transition_dim,
cond_dim, cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32, dim=32,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
): ):
super().__init__() super().__init__()
self.transition_dim = transition_dim
self.cond_dim = cond_dim
self.predict_epsilon = predict_epsilon
self.clip_denoised = clip_denoised
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}') # print(f'[ models/temporal ] Channel dimensions: {in_out}')
...@@ -138,19 +144,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -138,19 +144,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.downs.append( self.downs.append(
nn.ModuleList( nn.ModuleList(
[ [
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(), Downsample1d(dim_out) if not is_last else nn.Identity(),
] ]
) )
) )
if not is_last: if not is_last:
horizon = horizon // 2 training_horizon = training_horizon // 2
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
...@@ -158,15 +164,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): ...@@ -158,15 +164,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups.append( self.ups.append(
nn.ModuleList( nn.ModuleList(
[ [
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(), Upsample1d(dim_in) if not is_last else nn.Identity(),
] ]
) )
) )
if not is_last: if not is_last:
horizon = horizon * 2 training_horizon = training_horizon * 2
self.final_conv = nn.Sequential( self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5), Conv1dBlock(dim, dim, kernel_size=5),
...@@ -232,7 +238,6 @@ class TemporalValue(nn.Module): ...@@ -232,7 +238,6 @@ class TemporalValue(nn.Module):
print(in_out) print(in_out)
for dim_in, dim_out in in_out: for dim_in, dim_out in in_out:
self.blocks.append( self.blocks.append(
nn.ModuleList( nn.ModuleList(
[ [
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helpers functions
import functools
import string
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, get_timestep_embedding
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
# Function ported from StyleGAN2
def get_weight(module, shape, weight_var="weight", kernel_init=None):
"""Get/create weight tensor for a convolution or fully-connected layer."""
return module.param(weight_var, kernel_init, shape)
class Conv2d(nn.Module):
"""Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
def __init__(
self,
in_ch,
out_ch,
kernel,
up=False,
down=False,
resample_kernel=(1, 3, 3, 1),
use_bias=True,
kernel_init=None,
):
super().__init__()
assert not (up and down)
assert kernel >= 1 and kernel % 2 == 1
self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
if kernel_init is not None:
self.weight.data = kernel_init(self.weight.data.shape)
if use_bias:
self.bias = nn.Parameter(torch.zeros(out_ch))
self.up = up
self.down = down
self.resample_kernel = resample_kernel
self.kernel = kernel
self.use_bias = use_bias
def forward(self, x):
if self.up:
x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
elif self.down:
x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
else:
x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
if self.use_bias:
x = x + self.bias.reshape(1, -1, 1, 1)
return x
def naive_upsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H, 1, W, 1))
x = x.repeat(1, 1, 1, factor, 1, factor)
return torch.reshape(x, (-1, C, H * factor, W * factor))
def naive_downsample_2d(x, factor=2):
_N, C, H, W = x.shape
x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
return torch.mean(x, dim=(3, 5))
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
Padding is performed only once at the beginning, not between the
operations.
The fused op is considerably more efficient than performing the same
calculation
using standard TensorFlow ops. It supports gradients of arbitrary order.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels =
x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or
`[N, H * factor, W * factor, C]`, and same datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
# Check weight shape.
assert len(w.shape) == 4
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
assert convW == convH
# Setup filter kernel.
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
output_padding = (
output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
num_groups = _shape(x, 1) // inC
# Transpose weights.
w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
# Original TF code.
# x = tf.nn.conv2d_transpose(
# x,
# w,
# output_shape=output_shape,
# strides=stride,
# padding='VALID',
# data_format=data_format)
# JAX equivalent
return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations.
The fused op is considerably more efficient than performing the same
calculation
using standard TensorFlow ops. It supports gradients of arbitrary order.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels =
x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or
`[N, H // factor, W // factor, C]`, and same datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
_outC, _inC, convH, convW = w.shape
assert convW == convH
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2))
return F.conv2d(x, w, stride=s, padding=0)
def _setup_kernel(k):
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
assert k.ndim == 2
assert k.shape[0] == k.shape[1]
return k
def _shape(x, dim):
return x.shape[dim]
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
and upsamples each image with the given filter. The filter is normalized so
that
if the input pixels are constant, they will be scaled by the specified
`gain`.
Pixels outside the image are assumed to be zero, and the filter is padded
with
zeros so that its shape is a multiple of the upsampling factor.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
and downsamples each image with the given filter. The filter is normalized
so that
if the input pixels are constant, they will be scaled by the specified
`gain`.
Pixels outside the image are assumed to be zero, and the filter is padded
with
zeros so that its shape is a multiple of the downsampling factor.
Args:
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to
average pooling.
factor: Integer downsampling factor (default: 2).
gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = _setup_kernel(k) * gain
p = k.shape[0] - factor
return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
"""1x1 convolution with DDPM initialization."""
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
"""3x3 convolution with DDPM initialization."""
conv = nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
)
conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
nn.init.zeros_(conv.bias)
return conv
conv1x1 = ddpm_conv1x1
conv3x3 = ddpm_conv3x3
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
y = contract_inner(x, self.W) + self.b
return y.permute(0, 3, 1, 2)
def get_act(nonlinearity):
"""Get activation functions from the config file."""
if nonlinearity.lower() == "elu":
return nn.ELU()
elif nonlinearity.lower() == "relu":
return nn.ReLU()
elif nonlinearity.lower() == "lrelu":
return nn.LeakyReLU(negative_slope=0.2)
elif nonlinearity.lower() == "swish":
return nn.SiLU()
else:
raise NotImplementedError("activation function does not exist!")
def default_init(scale=1.0):
"""The same initialization used in DDPM."""
scale = 1e-10 if scale == 0 else scale
return variance_scaling(scale, "fan_avg", "uniform")
def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX."""
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in":
denominator = fan_in
elif mode == "fan_out":
denominator = fan_out
elif mode == "fan_avg":
denominator = (fan_in + fan_out) / 2
else:
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
variance = scale / denominator
if distribution == "normal":
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
elif distribution == "uniform":
return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return init
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
self.Conv_0 = conv1x1(dim1, dim2)
self.method = method
def forward(self, x, y):
h = self.Conv_0(x)
if self.method == "cat":
return torch.cat([h, y], dim=1)
elif self.method == "sum":
return h + y
else:
raise ValueError(f"Method {self.method} not recognized.")
class AttnBlockpp(nn.Module):
"""Channel-wise self-attention block. Modified from DDPM."""
def __init__(self, channels, skip_rescale=False, init_scale=0.0):
super().__init__()
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
self.skip_rescale = skip_rescale
def forward(self, x):
B, C, H, W = x.shape
h = self.GroupNorm_0(x)
q = self.NIN_0(h)
k = self.NIN_1(h)
v = self.NIN_2(h)
w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
w = torch.reshape(w, (B, H, W, H * W))
w = F.softmax(w, dim=-1)
w = torch.reshape(w, (B, H, W, H, W))
h = torch.einsum("bhwij,bcij->bchw", w, v)
h = self.NIN_3(h)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
class Upsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if not fir:
if with_conv:
self.Conv_0 = conv3x3(in_ch, out_ch)
else:
if with_conv:
self.Conv2d_0 = Conv2d(
in_ch,
out_ch,
kernel=3,
up=True,
resample_kernel=fir_kernel,
use_bias=True,
kernel_init=default_init(),
)
self.fir = fir
self.with_conv = with_conv
self.fir_kernel = fir_kernel
self.out_ch = out_ch
def forward(self, x):
B, C, H, W = x.shape
if not self.fir:
h = F.interpolate(x, (H * 2, W * 2), "nearest")
if self.with_conv:
h = self.Conv_0(h)
else:
if not self.with_conv:
h = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = self.Conv2d_0(x)
return h
class Downsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if not fir:
if with_conv:
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
else:
if with_conv:
self.Conv2d_0 = Conv2d(
in_ch,
out_ch,
kernel=3,
down=True,
resample_kernel=fir_kernel,
use_bias=True,
kernel_init=default_init(),
)
self.fir = fir
self.fir_kernel = fir_kernel
self.with_conv = with_conv
self.out_ch = out_ch
def forward(self, x):
B, C, H, W = x.shape
if not self.fir:
if self.with_conv:
x = F.pad(x, (0, 1, 0, 1))
x = self.Conv_0(x)
else:
x = F.avg_pool2d(x, 2, stride=2)
else:
if not self.with_conv:
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
x = self.Conv2d_0(x)
return x
class ResnetBlockDDPMpp(nn.Module):
"""ResBlock adapted from DDPM."""
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
conv_shortcut=False,
dropout=0.1,
skip_rescale=False,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch:
if conv_shortcut:
self.Conv_2 = conv3x3(in_ch, out_ch)
else:
self.NIN_0 = NIN(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.out_ch = out_ch
self.conv_shortcut = conv_shortcut
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
h = self.Conv_0(h)
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if x.shape[1] != self.out_ch:
if self.conv_shortcut:
x = self.Conv_2(x)
else:
x = self.NIN_0(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
class ResnetBlockBigGANpp(nn.Module):
def __init__(
self,
act,
in_ch,
out_ch=None,
temb_dim=None,
up=False,
down=False,
dropout=0.1,
fir=False,
fir_kernel=(1, 3, 3, 1),
skip_rescale=True,
init_scale=0.0,
):
super().__init__()
out_ch = out_ch if out_ch else in_ch
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up
self.down = down
self.fir = fir
self.fir_kernel = fir_kernel
self.Conv_0 = conv3x3(in_ch, out_ch)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
nn.init.zeros_(self.Dense_0.bias)
self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
self.Dropout_0 = nn.Dropout(dropout)
self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
if in_ch != out_ch or up or down:
self.Conv_2 = conv1x1(in_ch, out_ch)
self.skip_rescale = skip_rescale
self.act = act
self.in_ch = in_ch
self.out_ch = out_ch
def forward(self, x, temb=None):
h = self.act(self.GroupNorm_0(x))
if self.up:
if self.fir:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_upsample_2d(h, factor=2)
x = naive_upsample_2d(x, factor=2)
elif self.down:
if self.fir:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_downsample_2d(h, factor=2)
x = naive_downsample_2d(x, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding
if temb is not None:
h += self.Dense_0(self.act(temb))[:, :, None, None]
h = self.act(self.GroupNorm_1(h))
h = self.Dropout_0(h)
h = self.Conv_1(h)
if self.in_ch != self.out_ch or self.up or self.down:
x = self.Conv_2(x)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model"""
def __init__(
self,
centered=False,
image_size=1024,
num_channels=3,
attention_type="ddpm",
attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True,
conv_size=3,
dropout=0.0,
embedding_type="fourier",
fir=True,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
nf=16,
nonlinearity="swish",
normalization="GroupNorm",
num_res_blocks=1,
progressive="output_skip",
progressive_combine="sum",
progressive_input="input_skip",
resamp_with_conv=True,
resblock_type="biggan",
scale_by_sigma=True,
skip_rescale=True,
continuous=True,
):
super().__init__()
self.register_to_config(
centered=centered,
image_size=image_size,
num_channels=num_channels,
attention_type=attention_type,
attn_resolutions=attn_resolutions,
ch_mult=ch_mult,
conditional=conditional,
conv_size=conv_size,
dropout=dropout,
embedding_type=embedding_type,
fir=fir,
fir_kernel=fir_kernel,
fourier_scale=fourier_scale,
init_scale=init_scale,
nf=nf,
nonlinearity=nonlinearity,
normalization=normalization,
num_res_blocks=num_res_blocks,
progressive=progressive,
progressive_combine=progressive_combine,
progressive_input=progressive_input,
resamp_with_conv=resamp_with_conv,
resblock_type=resblock_type,
scale_by_sigma=scale_by_sigma,
skip_rescale=skip_rescale,
continuous=continuous,
)
self.act = act = get_act(nonlinearity)
self.nf = nf
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
self.conditional = conditional
self.skip_rescale = skip_rescale
self.resblock_type = resblock_type
self.progressive = progressive
self.progressive_input = progressive_input
self.embedding_type = embedding_type
assert progressive in ["none", "output_skip", "residual"]
assert progressive_input in ["none", "input_skip", "residual"]
assert embedding_type in ["fourier", "positional"]
combine_method = progressive_combine.lower()
combiner = functools.partial(Combine, method=combine_method)
modules = []
# timestep/noise_level embedding; only for continuous training
if embedding_type == "fourier":
# Gaussian Fourier features embeddings.
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
embed_dim = 2 * nf
elif embedding_type == "positional":
embed_dim = nf
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
if conditional:
modules.append(nn.Linear(embed_dim, nf * 4))
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
modules.append(nn.Linear(nf * 4, nf * 4))
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
if resblock_type == "ddpm":
ResnetBlock = functools.partial(
ResnetBlockDDPMpp,
act=act,
dropout=dropout,
init_scale=init_scale,
skip_rescale=skip_rescale,
temb_dim=nf * 4,
)
elif resblock_type == "biggan":
ResnetBlock = functools.partial(
ResnetBlockBigGANpp,
act=act,
dropout=dropout,
fir=fir,
fir_kernel=fir_kernel,
init_scale=init_scale,
skip_rescale=skip_rescale,
temb_dim=nf * 4,
)
else:
raise ValueError(f"resblock type {resblock_type} unrecognized.")
# Downsampling block
channels = num_channels
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(conv3x3(channels, nf))
hs_c = [nf]
in_ch = nf
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
hs_c.append(in_ch)
if i_level != self.num_resolutions - 1:
if resblock_type == "ddpm":
modules.append(Downsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(down=True, in_ch=in_ch))
if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
if combine_method == "cat":
in_ch *= 2
elif progressive_input == "residual":
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
in_ch = hs_c[-1]
modules.append(ResnetBlock(in_ch=in_ch))
modules.append(AttnBlock(channels=in_ch))
modules.append(ResnetBlock(in_ch=in_ch))
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
if progressive != "none":
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
pyramid_ch = channels
elif progressive == "residual":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, in_ch, bias=True))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
pyramid_ch = channels
elif progressive == "residual":
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
pyramid_ch = in_ch
else:
raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
if resblock_type == "ddpm":
modules.append(Upsample(in_ch=in_ch))
else:
modules.append(ResnetBlock(in_ch=in_ch, up=True))
assert not hs_c
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
self.all_modules = nn.ModuleList(modules)
def forward(self, x, time_cond, sigmas=None):
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
if self.embedding_type == "fourier":
# Gaussian Fourier features embeddings.
used_sigmas = time_cond
temb = modules[m_idx](torch.log(used_sigmas))
m_idx += 1
elif self.embedding_type == "positional":
# Sinusoidal positional embeddings.
timesteps = time_cond
used_sigmas = sigmas
temb = get_timestep_embedding(timesteps, self.nf)
else:
raise ValueError(f"embedding type {self.embedding_type} unknown.")
if self.conditional:
temb = modules[m_idx](temb)
m_idx += 1
temb = modules[m_idx](self.act(temb))
m_idx += 1
else:
temb = None
if not self.config.centered:
# If input data is in [0, 1]
x = 2 * x - 1.0
# Downsampling block
input_pyramid = None
if self.progressive_input != "none":
input_pyramid = x
hs = [modules[m_idx](x)]
m_idx += 1
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
hs.append(h)
if i_level != self.num_resolutions - 1:
if self.resblock_type == "ddpm":
h = modules[m_idx](hs[-1])
m_idx += 1
else:
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if self.progressive_input == "input_skip":
input_pyramid = self.pyramid_downsample(input_pyramid)
h = modules[m_idx](input_pyramid, h)
m_idx += 1
elif self.progressive_input == "residual":
input_pyramid = modules[m_idx](input_pyramid)
m_idx += 1
if self.skip_rescale:
input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
else:
input_pyramid = input_pyramid + h
h = input_pyramid
hs.append(h)
h = hs[-1]
h = modules[m_idx](h, temb)
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
h = modules[m_idx](h, temb)
m_idx += 1
pyramid = None
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
if self.progressive != "none":
if i_level == self.num_resolutions - 1:
if self.progressive == "output_skip":
pyramid = self.act(modules[m_idx](h))
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
elif self.progressive == "residual":
pyramid = self.act(modules[m_idx](h))
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
else:
raise ValueError(f"{self.progressive} is not a valid name.")
else:
if self.progressive == "output_skip":
pyramid = self.pyramid_upsample(pyramid)
pyramid_h = self.act(modules[m_idx](h))
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
pyramid = pyramid + pyramid_h
elif self.progressive == "residual":
pyramid = modules[m_idx](pyramid)
m_idx += 1
if self.skip_rescale:
pyramid = (pyramid + h) / np.sqrt(2.0)
else:
pyramid = pyramid + h
h = pyramid
else:
raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0:
if self.resblock_type == "ddpm":
h = modules[m_idx](h)
m_idx += 1
else:
h = modules[m_idx](h, temb)
m_idx += 1
assert not hs
if self.progressive == "output_skip":
h = pyramid
else:
h = self.act(modules[m_idx](h))
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
assert m_idx == len(modules)
if self.config.scale_by_sigma:
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
h = h / used_sigmas
return h
...@@ -21,7 +21,6 @@ from typing import Optional, Union ...@@ -21,7 +21,6 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging from .utils import DIFFUSERS_CACHE, logging
...@@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin): ...@@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin):
# set models # set models
setattr(self, name, module) setattr(self, name, module)
register_dict = {"_module": self.__module__.split(".")[-1]}
self.register_to_config(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory) self.save_config(save_directory)
model_index_dict = dict(self.config) model_index_dict = dict(self.config)
model_index_dict.pop("_class_name") model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version") model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module") model_index_dict.pop("_module", None)
for pipeline_component_name in model_index_dict.keys(): for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name) sub_model = getattr(self, pipeline_component_name)
...@@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder) config_dict = cls.get_config_dict(cached_folder)
# 2. Get class name and module candidates to load custom models # 2. Load the pipeline class, if using custom module then load it from the hub
module_candidate_name = config_dict["_module"]
module_candidate = module_candidate_name + ".py"
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
if cls != DiffusionPipeline: if cls != DiffusionPipeline:
pipeline_class = cls pipeline_class = cls
...@@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin): ...@@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# (TODO - we should allow to load custom pipelines
# else we need to load the correct module from the Hub
# module = module_candidate
# pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {} init_kwargs = {}
...@@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin):
# import it here to avoid circular import # import it here to avoid circular import
from diffusers import pipelines from diffusers import pipelines
# 4. Load each module in the pipeline # 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
# if the model is in a pipeline module, then we load it from the pipeline # if the model is in a pipeline module, then we load it from the pipeline
...@@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin):
pipeline_module = getattr(pipelines, library_name) pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name) class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} class_candidates = {c: class_obj for c in importable_classes.keys()}
elif library_name == module_candidate_name:
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
else: else:
# else we just import it from the library. # else we just import it from the library.
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
......
...@@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj) ...@@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj)
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py). - PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py). - Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py). - Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py). - BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py). - Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDM from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIM from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPM from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDM from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
if is_transformers_available(): if is_transformers_available():
from .pipeline_glide import Glide from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_latent_diffusion import LatentDiffusionPipeline
if is_transformers_available() and is_unidecode_available() and is_inflect_available(): if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .pipeline_grad_tts import GradTTS from .pipeline_grad_tts import GradTTSPipeline
...@@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin): ...@@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin):
return self.final_conv(x) return self.final_conv(x)
class BDDM(DiffusionPipeline): class BDDMPipeline(DiffusionPipeline):
def __init__(self, diffwave, noise_scheduler): def __init__(self, diffwave, noise_scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") noise_scheduler = noise_scheduler.set_format("pt")
......
...@@ -21,7 +21,7 @@ import tqdm ...@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
class DDIM(DiffusionPipeline): class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") noise_scheduler = noise_scheduler.set_format("pt")
......
...@@ -21,7 +21,7 @@ import tqdm ...@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
class DDPM(DiffusionPipeline): class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") noise_scheduler = noise_scheduler.set_format("pt")
......
...@@ -695,7 +695,23 @@ class CLIPTextModel(CLIPPreTrainedModel): ...@@ -695,7 +695,23 @@ class CLIPTextModel(CLIPPreTrainedModel):
##################### #####################
class Glide(DiffusionPipeline): def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
class GlidePipeline(DiffusionPipeline):
def __init__( def __init__(
self, self,
text_unet: GlideTextToImageUNetModel, text_unet: GlideTextToImageUNetModel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment