Commit e13ee8b5 authored by patil-suraj

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents 0027993e 6846ee2a
@@ -11,47 +11,75 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math

+import numpy as np
 import torch
+from torch import nn


-# unet.py
-def get_timestep_embedding(timesteps, embedding_dim):
+def get_timestep_embedding(
+    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
+):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
-    From Fairseq.
-    Build sinusoidal embeddings.
-    This matches the implementation in tensor2tensor, but differs slightly
-    from the description in Section 3.5 of "Attention Is All You Need".
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
     """
-    assert len(timesteps.shape) == 1
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
     half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-    emb = emb.to(device=timesteps.device)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
+    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
+    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
+    emb = torch.exp(emb * emb_coeff)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
     return emb
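
For orientation, a minimal usage sketch of the unified helper as it is exercised by this commit. The absolute import path is an assumption inferred from the relative `from .embeddings import get_timestep_embedding` lines further down, and the sizes are illustrative; the three presets correspond to the DDPM-style default and to the GLIDE/LDM and Grad-TTS call sites updated below.

import torch

from diffusers.models.embeddings import get_timestep_embedding  # import path assumed

timesteps = torch.arange(4)  # one integer timestep per batch element

# default preset: sin-then-cos order and the "half_dim - 1" divisor (old unet.py / DDPM behaviour)
emb_ddpm = get_timestep_embedding(timesteps, embedding_dim=128)

# GLIDE / LDM preset used at the call sites below: cos-then-sin order, no frequency shift
emb_glide = get_timestep_embedding(timesteps, embedding_dim=128, flip_sin_to_cos=True, downscale_freq_shift=0)

# Grad-TTS preset: the sinusoid arguments are multiplied by pe_scale
emb_grad_tts = get_timestep_embedding(timesteps, embedding_dim=128, scale=1000)

print(emb_ddpm.shape, emb_glide.shape, emb_grad_tts.shape)  # three [4, 128] tensors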
-# unet_glide.py
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
+# unet_sde_score_estimation.py
+class GaussianFourierProjection(nn.Module):
+    """Gaussian Fourier embeddings for noise levels."""
+
+    def __init__(self, embedding_size=256, scale=1.0):
+        super().__init__()
+        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+    def forward(self, x):
+        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+
+# unet_rl.py - TODO(need test)
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
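
A short usage sketch for the two module-style embeddings that now live alongside the helper, assuming the classes defined above are in scope; the `embedding_size`, `scale`, and `dim` values are illustrative.

import torch

# assumes GaussianFourierProjection and SinusoidalPosEmb from above are in scope
fourier = GaussianFourierProjection(embedding_size=256, scale=16.0)
sigmas = torch.rand(4)            # continuous noise levels, one per sample
print(fourier(sigmas).shape)      # torch.Size([4, 512]) -- sin and cos halves concatenated

pos_emb = SinusoidalPosEmb(dim=64)
steps = torch.arange(4).float()
print(pos_emb(steps).shape)       # torch.Size([4, 64])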
@@ -30,27 +30,7 @@ from tqdm import tqdm

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding


-def get_timestep_embedding(timesteps, embedding_dim):
-    """
-    This matches the implementation in Denoising Diffusion Probabilistic Models:
-    From Fairseq.
-    Build sinusoidal embeddings.
-    This matches the implementation in tensor2tensor, but differs slightly
-    from the description in Section 3.5 of "Attention Is All You Need".
-    """
-    assert len(timesteps.shape) == 1
-    half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-    emb = emb.to(device=timesteps.device)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-    return emb


 def nonlinearity(x):
@@ -7,6 +7,7 @@ import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding


 def convert_module_to_f16(l):
@@ -86,27 +87,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding


 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
@@ -627,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         """
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         h = x.type(self.dtype)
         for module in self.input_blocks:
@@ -714,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
     def forward(self, x, timesteps, transformer_out=None):
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         # project the last token
         transformer_proj = self.transformer_proj(transformer_out[:, -1])
@@ -806,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
         x = torch.cat([x, upsampled], dim=1)
         hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         h = x
         for module in self.input_blocks:
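
The two extra keyword arguments at the GLIDE call sites are what keep the time embedding numerically unchanged. A self-contained sanity-check sketch, with the deleted helper re-declared locally for comparison and the absolute import path assumed:

import math

import torch

from diffusers.models.embeddings import get_timestep_embedding  # import path assumed


def removed_timestep_embedding(timesteps, dim, max_period=10000):
    # re-declaration of the helper deleted above, kept only for the comparison
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


t = torch.linspace(0, 999, 8)
old = removed_timestep_embedding(t, 192)
new = get_timestep_embedding(t, 192, flip_sin_to_cos=True, downscale_freq_shift=0)
assert torch.allclose(old, new)  # cos-then-sin ordering and the "half" divisor are reproduced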
-import math
 import torch
@@ -11,6 +9,7 @@ except:

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding


 class Mish(torch.nn.Module):
@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
         return output


-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb


 class UNetGradTTSModel(ModelMixin, ConfigMixin):
     def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
         super(UNetGradTTSModel, self).__init__()
@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                 torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
             )
-        self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

         dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
@@ -198,7 +181,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)

-        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
+        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
         t = self.mlp(t)

         if self.n_spks < 2:
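
A sketch of the updated Grad-TTS time-embedding path. Here `nn.Mish` stands in for the file's own Mish class, the sizes are illustrative, and the absolute import path is assumed; with `scale=pe_scale` the helper matches the removed SinusoidalPosEmb output.

import torch
from torch import nn

from diffusers.models.embeddings import get_timestep_embedding  # import path assumed

dim, pe_scale = 64, 1000
# stand-in for self.mlp above; the real module uses its own Mish activation
mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.Mish(), nn.Linear(dim * 4, dim))

timesteps = torch.rand(8)  # fractional timesteps are allowed (see the docstring above)
t = get_timestep_embedding(timesteps, dim, scale=pe_scale)
out = mlp(t)
print(t.shape, out.shape)  # torch.Size([8, 64]) torch.Size([8, 64])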
@@ -16,6 +16,7 @@ except:

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding


 def exists(val):
@@ -316,36 +317,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
-
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module


 ## go
 class AttentionPool2d(nn.Module):
     """
@@ -1026,7 +997,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         hs = []
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
-        t_emb = timestep_embedding(timesteps, self.model_channels)
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
         emb = self.time_embed(t_emb)

         if self.num_classes is not None:
@@ -1240,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )

         results = []
         h = x.type(self.dtype)
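
Since the unified helper asserts a 1-D tensor, the scalar-to-tensor coercion kept in `UNetLDMModel.forward` above is what lets callers still pass a bare integer timestep. A minimal sketch (import path assumed, size illustrative):

import torch

from diffusers.models.embeddings import get_timestep_embedding  # import path assumed

timesteps = 50  # a bare Python int, as a caller might pass it
if not torch.is_tensor(timesteps):
    timesteps = torch.tensor([timesteps], dtype=torch.long)

emb = get_timestep_embedding(timesteps, 224, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([1, 224])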
@@ -16,7 +16,6 @@
 # helpers functions

 import functools
-import math
 import string

 import numpy as np
@@ -26,6 +25,7 @@ import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import GaussianFourierProjection, get_timestep_embedding


 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -381,23 +381,6 @@ def get_act(nonlinearity):
     raise NotImplementedError("activation function does not exist!")


-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-    half_dim = embedding_dim // 2
-    # magic number 10000 is from transformers
-    emb = math.log(max_positions) / (half_dim - 1)
-    # emb = math.log(2.) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-    # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = F.pad(emb, (0, 1), mode="constant")
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb


 def default_init(scale=1.0):
     """The same initialization used in DDPM."""
     scale = 1e-10 if scale == 0 else scale
@@ -434,18 +417,6 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
     return init


-class GaussianFourierProjection(nn.Module):
-    """Gaussian Fourier embeddings for noise levels."""
-
-    def __init__(self, embedding_size=256, scale=1.0):
-        super().__init__()
-        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
-    def forward(self, x):
-        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


 class Combine(nn.Module):
     """Combine information from skip connections."""
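
The removed score-SDE helper used the same `half_dim - 1` divisor and sin-then-cos ordering as the DDPM variant, so the unified helper's defaults stand in for it here; only `max_positions` is now named `max_period`. A minimal sketch (import path assumed):

import torch

from diffusers.models.embeddings import get_timestep_embedding  # import path assumed

t = torch.arange(16)
emb = get_timestep_embedding(t, 128, max_period=10000)  # defaults: downscale_freq_shift=1, no flip, scale=1
assert emb.shape == (16, 128)  # mirrors the removed helper's final shape assert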