Commit c7a39d38 authored by Patrick von Platen

refactor all sinus embeddings

parent 02a76c2c
@@ -11,15 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 import math
-import numpy as np
+import numpy as np
+import torch
 from torch import nn
-import torch.nn.functional as F
 
-def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
+def get_timestep_embedding(
+    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
+):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -31,16 +32,20 @@ def get_timestep_embedding
     :param max_period: controls the minimum frequency of the embeddings.
     :return: an [N x dim] Tensor of positional embeddings.
     """
-    assert len(timesteps.shape) == 1
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
     half_dim = embedding_dim // 2
-    emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift))
-    emb = emb.to(device=timesteps.device)
+    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
+    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
+    emb = torch.exp(emb * emb_coeff)
     emb = timesteps[:, None].float() * emb[None, :]
+    # scale embeddings
+    emb = scale * emb
     # concat sine and cosine embeddings
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
@@ -52,81 +57,20 @@ def get_timestep_embedding
     return emb
 
-# def get_timestep_embedding(timesteps, embedding_dim):
-#     """
-#     This matches the implementation in Denoising Diffusion Probabilistic Models:
-#     From Fairseq.
-#     Build sinusoidal embeddings.
-#     This matches the implementation in tensor2tensor, but differs slightly
-#     from the description in Section 3.5 of "Attention Is All You Need".
-#     """
-#     assert len(timesteps.shape) == 1
-#
-#     half_dim = embedding_dim // 2
-#     emb = math.log(10000) / (half_dim - 1)
-#     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-#     emb = emb.to(device=timesteps.device)
-#     emb = timesteps.float()[:, None] * emb[None, :]
-#     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-#     if embedding_dim % 2 == 1:  # zero pad
-#         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-
-# def timestep_embedding(timesteps, dim, max_period=10000):
-#     """
-#     Create sinusoidal timestep embeddings.
-#
-#     :param timesteps: a 1-D Tensor of N indices, one per batch element.
-#                       These may be fractional.
-#     :param dim: the dimension of the output.
-#     :param max_period: controls the minimum frequency of the embeddings.
-#     :return: an [N x dim] Tensor of positional embeddings.
-#     """
-#     half = dim // 2
-#     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-#         device=timesteps.device
-#     )
-#     args = timesteps[:, None].float() * freqs[None, :]
-#     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-#     if dim % 2:
-#         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-#     return embedding
-
-# def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-#     assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-#     half_dim = embedding_dim // 2
-#     # magic number 10000 is from transformers
-#     emb = math.log(max_positions) / (half_dim - 1)
-#     # emb = math.log(2.) / (half_dim - 1)
-#     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-#     # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-#     # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-#     emb = timesteps.float()[:, None] * emb[None, :]
-#     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-#     if embedding_dim % 2 == 1:  # zero pad
-#         emb = F.pad(emb, (0, 1), mode="constant")
-#     assert emb.shape == (timesteps.shape[0], embedding_dim)
-#     return emb
-
-# unet_grad_tts.py
-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
+# unet_sde_score_estimation.py
+class GaussianFourierProjection(nn.Module):
+    """Gaussian Fourier embeddings for noise levels."""
+
+    def __init__(self, embedding_size=256, scale=1.0):
+        super().__init__()
+        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+    def forward(self, x):
+        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
-# unet_rl.py
+# unet_rl.py - TODO(need test)
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -140,16 +84,3 @@ class SinusoidalPosEmb(nn.Module):
         emb = x[:, None] * emb[None, :]
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
-
-# unet_sde_score_estimation.py
-class GaussianFourierProjection(nn.Module):
-    """Gaussian Fourier embeddings for noise levels."""
-
-    def __init__(self, embedding_size=256, scale=1.0):
-        super().__init__()
-        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
-    def forward(self, x):
-        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
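
Note for readers of this diff: below is a minimal sketch of the consolidated helper in embeddings.py, assembled from the hunks above, together with the three call patterns the new tests exercise (DDPM/score-SDE defaults, GLIDE/LDM, Grad-TTS). The body of the flip_sin_to_cos branch is not visible in this diff, so that line is reconstructed from test_timestep_flip_sin_cos; treat this as an illustration, not a verbatim copy of the final file.

import math

import torch


def get_timestep_embedding_sketch(
    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
    # mirrors the refactored helper shown in the hunks above
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    emb = torch.exp(emb * emb_coeff)
    emb = timesteps[:, None].float() * emb[None, :]
    emb = scale * emb
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if flip_sin_to_cos:
        # reconstructed from test_timestep_flip_sin_cos: put the cosine half first
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    return emb


timesteps = torch.arange(16)
t_ddpm = get_timestep_embedding_sketch(timesteps, 64)  # DDPM / score-SDE style (defaults)
t_glide = get_timestep_embedding_sketch(timesteps, 64, flip_sin_to_cos=True, downscale_freq_shift=0)  # GLIDE / LDM
t_gradtts = get_timestep_embedding_sketch(timesteps, 64, scale=1000)  # Grad-TTS style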
@@ -33,27 +33,6 @@ from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 
-# def get_timestep_embedding(timesteps, embedding_dim):
-#     """
-#     This matches the implementation in Denoising Diffusion Probabilistic Models:
-#     From Fairseq.
-#     Build sinusoidal embeddings.
-#     This matches the implementation in tensor2tensor, but differs slightly
-#     from the description in Section 3.5 of "Attention Is All You Need".
-#     """
-#     assert len(timesteps.shape) == 1
-#
-#     half_dim = embedding_dim // 2
-#     emb = math.log(10000) / (half_dim - 1)
-#     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-#     emb = emb.to(device=timesteps.device)
-#     emb = timesteps.float()[:, None] * emb[None, :]
-#     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-#     if embedding_dim % 2 == 1:  # zero pad
-#         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-#     return emb
 
 def nonlinearity(x):
     # swish
     return x * torch.sigmoid(x)
...
@@ -87,27 +87,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
 
-# def timestep_embedding(timesteps, dim, max_period=10000):
-#     """
-#     Create sinusoidal timestep embeddings.
-#
-#     :param timesteps: a 1-D Tensor of N indices, one per batch element.
-#                       These may be fractional.
-#     :param dim: the dimension of the output.
-#     :param max_period: controls the minimum frequency of the embeddings.
-#     :return: an [N x dim] Tensor of positional embeddings.
-#     """
-#     half = dim // 2
-#     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-#         device=timesteps.device
-#     )
-#     args = timesteps[:, None].float() * freqs[None, :]
-#     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-#     if dim % 2:
-#         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-#     return embedding
 
 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
@@ -628,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         """
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )
 
         h = x.type(self.dtype)
         for module in self.input_blocks:
@@ -715,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):
     def forward(self, x, timesteps, transformer_out=None):
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )
 
         # project the last token
         transformer_proj = self.transformer_proj(transformer_out[:, -1])
@@ -807,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
         x = torch.cat([x, upsampled], dim=1)
         hs = []
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )
 
         h = x
         for module in self.input_blocks:
...
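
The three GLIDE forward passes above only re-wrap the same call over multiple lines. For context, here is a minimal sketch of how such a model consumes the embedding; the two-layer time_embed MLP is a hypothetical stand-in for illustration (the real module definition is not part of this diff), while the get_timestep_embedding call matches the hunks above.

import torch
from torch import nn

from diffusers.models.embeddings import get_timestep_embedding

model_channels = 128
# hypothetical stand-in for the model's time-embedding MLP (not shown in this diff)
time_embed = nn.Sequential(
    nn.Linear(model_channels, 4 * model_channels),
    nn.SiLU(),
    nn.Linear(4 * model_channels, 4 * model_channels),
)

timesteps = torch.randint(0, 1000, (8,))
emb = time_embed(
    get_timestep_embedding(timesteps, model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
)
print(emb.shape)  # torch.Size([8, 512])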
-import math
 import torch
@@ -11,6 +9,7 @@ except:
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding
 
 class Mish(torch.nn.Module):
@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
         return output
 
-class SinusoidalPosEmb(torch.nn.Module):
-    def __init__(self, dim):
-        super(SinusoidalPosEmb, self).__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
 
 class UNetGradTTSModel(ModelMixin, ConfigMixin):
     def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
         super(UNetGradTTSModel, self).__init__()
@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                 torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
             )
-        self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
 
         dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
@@ -198,8 +181,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)
 
-        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
+        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
         t = self.mlp(t)
 
         if self.n_spks < 2:
...
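
A quick sanity check on the Grad-TTS swap, as a sketch: with the defaults (downscale_freq_shift=1, flip_sin_to_cos=False), get_timestep_embedding(t, dim, scale=1000) reproduces what the removed SinusoidalPosEmb(dim)(t, scale=1000) computed, which is the same configuration the hardcoded t3 values in the new test cover.

import math

import torch

from diffusers.models.embeddings import get_timestep_embedding


def old_grad_tts_pos_emb(x, dim, scale=1000):
    # the removed SinusoidalPosEmb.forward, reproduced from the hunk above
    half_dim = dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
    emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
    return torch.cat((emb.sin(), emb.cos()), dim=-1)


timesteps = torch.arange(128)
old = old_grad_tts_pos_emb(timesteps.float(), 64)
new = get_timestep_embedding(timesteps, 64, scale=1000)
assert torch.allclose(old, new, atol=1e-5)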
@@ -317,27 +317,6 @@ def normalization(channels, swish=0.0):
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
 
-# def timestep_embedding(timesteps, dim, max_period=10000):
-#     """
-#     Create sinusoidal timestep embeddings.
-#
-#     :param timesteps: a 1-D Tensor of N indices, one per batch element.
-#                       These may be fractional.
-#     :param dim: the dimension of the output.
-#     :param max_period: controls the minimum frequency of the embeddings.
-#     :return: an [N x dim] Tensor of positional embeddings.
-#     """
-#     half = dim // 2
-#     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-#         device=timesteps.device
-#     )
-#     args = timesteps[:, None].float() * freqs[None]
-#     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-#     if dim % 2:
-#         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-#     return embedding
 
 ## go
 class AttentionPool2d(nn.Module):
     """
@@ -1232,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
         :param timesteps: a 1-D batch of timesteps.
         :return: an [N x K] Tensor of outputs.
         """
-        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
+        emb = self.time_embed(
+            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        )
 
         results = []
         h = x.type(self.dtype)
...
@@ -382,23 +382,6 @@ def get_act(nonlinearity):
         raise NotImplementedError("activation function does not exist!")
 
-# def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-#     assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-#     half_dim = embedding_dim // 2
-#     # magic number 10000 is from transformers
-#     emb = math.log(max_positions) / (half_dim - 1)
-#     # emb = math.log(2.) / (half_dim - 1)
-#     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-#     # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-#     # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-#     emb = timesteps.float()[:, None] * emb[None, :]
-#     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-#     if embedding_dim % 2 == 1:  # zero pad
-#         emb = F.pad(emb, (0, 1), mode="constant")
-#     assert emb.shape == (timesteps.shape[0], embedding_dim)
-#     return emb
 
 def default_init(scale=1.0):
     """The same initialization used in DDPM."""
     scale = 1e-10 if scale == 0 else scale
...
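
This file's commented-out helper is removed in favor of the shared implementation, and the GaussianFourierProjection used for continuous noise levels now lives in embeddings.py (see the first hunks of this commit). A minimal usage sketch, assuming the class is importable from diffusers.models.embeddings after this change:

import torch

from diffusers.models.embeddings import GaussianFourierProjection

proj = GaussianFourierProjection(embedding_size=256, scale=1.0)

sigmas = torch.rand(8)  # one continuous noise level per sample
emb = proj(sigmas)      # concatenated [sin, cos] random Fourier features
print(emb.shape)        # torch.Size([8, 512])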
@@ -21,8 +21,7 @@ import unittest
 import numpy as np
 import torch
 
-#from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding, a_get_timestep_embedding
-from diffusers.models.embeddings import get_timestep_embedding, timestep_embedding
+from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -30,15 +29,87 @@ torch.backends.cuda.matmul.allow_tf32 = False
 class EmbeddingsTests(unittest.TestCase):
     def test_timestep_embeddings(self):
+        embedding_dim = 256
+        timesteps = torch.arange(16)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim)
+
+        # first vector should always be composed only of 0's and 1's
+        assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
+        assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5
+
+        # last element of each vector should be one
+        assert (t1[:, -1] - 1).abs().sum() < 1e-5
+
+        # For large embeddings (e.g. 128) the frequency of every vector is higher
+        # than the previous one which means that the gradients of later vectors are
+        # ALWAYS higher than the previous ones
+        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)
+
+        prev_grad = 0.0
+        for grad in grad_mean:
+            assert grad > prev_grad
+            prev_grad = grad
+
+    def test_timestep_defaults(self):
         embedding_dim = 16
         timesteps = torch.arange(10)
 
         t1 = get_timestep_embedding(timesteps, embedding_dim)
-        t2 = timestep_embedding(timesteps, embedding_dim)
-        t3 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True, downscale_freq_factor=8)
-
-        import ipdb; ipdb.set_trace()
+        t2 = get_timestep_embedding(
+            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
+        )
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_flip_sin_cos(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
+        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)
+
+        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)
+
+        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
+
+    def test_timestep_downscale_freq_shift(self):
+        embedding_dim = 16
+        timesteps = torch.arange(10)
+
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)
+
+        # get cosine half (vectors that are wrapped into cosine)
+        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]
+
+        # cosine needs to be negative
+        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5
+
+    def test_sinoid_embeddings_hardcoded(self):
+        embedding_dim = 64
+        timesteps = torch.arange(128)
+
+        # standard unet, score_vde
+        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
+        # glide, ldm
+        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
+        # grad-tts
+        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)
+
+        assert torch.allclose(
+            t1[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t2[23:26, 47:50].flatten().cpu(),
+            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
+            1e-3,
+        )
+        assert torch.allclose(
+            t3[23:26, 47:50].flatten().cpu(),
+            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
+            1e-3,
+        )