# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
from torch import nn


def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic
    Models.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: if True, the cosine half of the embedding comes before the sine half.
    :param downscale_freq_shift: subtracted from the frequency denominator (embedding_dim // 2); shifts how the
        frequencies are spaced.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "timesteps should be a 1-D tensor"

    half_dim = embedding_dim // 2
    emb = torch.exp(
        -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - downscale_freq_shift)
    )
    emb = emb.to(device=timesteps.device)
    emb = timesteps[:, None].float() * emb[None, :]

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad the last dimension when embedding_dim is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
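

# Illustrative usage sketch (comment only, not part of the module's public surface): for a batch of
# integer timesteps and an even embedding_dim, get_timestep_embedding returns one sinusoidal row per
# timestep.
#
#   timesteps = torch.arange(4)                                      # shape [4]
#   emb = get_timestep_embedding(timesteps, 128)                     # shape [4, 128]
#   emb_cos_first = get_timestep_embedding(timesteps, 128, flip_sin_to_cos=True)
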
# """ # half = dim // 2 # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( # device=timesteps.device # ) # args = timesteps[:, None].float() * freqs[None, :] # embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # if dim % 2: # embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) # return embedding #def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): # assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 # half_dim = embedding_dim // 2 # magic number 10000 is from transformers # emb = math.log(max_positions) / (half_dim - 1) # emb = math.log(2.) / (half_dim - 1) # emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] # emb = timesteps.float()[:, None] * emb[None, :] # emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) # if embedding_dim % 2 == 1: # zero pad # emb = F.pad(emb, (0, 1), mode="constant") # assert emb.shape == (timesteps.shape[0], embedding_dim) # return emb # unet_grad_tts.py class SinusoidalPosEmb(torch.nn.Module): def __init__(self, dim): super(SinusoidalPosEmb, self).__init__() self.dim = dim def forward(self, x, scale=1000): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb # unet_rl.py class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb # unet_sde_score_estimation.py class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__(self, embedding_size=256, scale=1.0): super().__init__() self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) def forward(self, x): x_proj = x[:, None] * self.W[None, :] * 2 * np.pi return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)