Commit 02a76c2c authored by Patrick von Platen

consolidate timestep embeds

parent 014ebc59
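The diff below collapses the per-model sinusoidal embedding helpers (unet.py, unet_glide.py, unet_ldm.py, unet_sde_score_estimation.py) into a single shared get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000) and points the UNet call sites at it. The defaults reproduce the DDPM-style embedding previously defined in unet.py (frequencies scaled by half_dim - 1, sine before cosine), while flip_sin_to_cos=True together with downscale_freq_shift=0 reproduces the GLIDE/LDM-style timestep_embedding (frequencies scaled by half_dim, cosine before sine). A minimal standalone sketch of that equivalence; the helper names here are illustrative only, and device handling plus odd-dimension zero padding are left out:

import math
import torch


def consolidated_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
    # same math as the new shared get_timestep_embedding introduced in this commit
    half_dim = embedding_dim // 2
    emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - downscale_freq_shift))
    emb = timesteps[:, None].float() * emb[None, :]
    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    # optionally flip to cosine-first ordering
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    return emb


def ddpm_style_embedding(timesteps, embedding_dim):
    # the old unet.py helper: frequencies scaled by half_dim - 1, sine before cosine
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)


def glide_style_embedding(timesteps, dim, max_period=10000):
    # the old unet_glide.py / unet_ldm.py helper: frequencies scaled by half, cosine before sine
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


timesteps = torch.arange(0, 1000, 50)
assert torch.allclose(consolidated_embedding(timesteps, 128), ddpm_style_embedding(timesteps, 128))
assert torch.allclose(
    consolidated_embedding(timesteps, 128, flip_sin_to_cos=True, downscale_freq_shift=0),
    glide_style_embedding(timesteps, 128),
)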
@@ -11,49 +11,104 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+import math
+import numpy as np
+from torch import nn
+import torch.nn.functional as F

-# unet.py
-def get_timestep_embedding(timesteps, embedding_dim):
-    """
-    This matches the implementation in Denoising Diffusion Probabilistic Models:
-    From Fairseq.
-    Build sinusoidal embeddings.
-    This matches the implementation in tensor2tensor, but differs slightly
-    from the description in Section 3.5 of "Attention Is All You Need".
-    """
-    assert len(timesteps.shape) == 1
-
-    half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-    emb = emb.to(device=timesteps.device)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-    return emb
-
-# unet_glide.py
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
+def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10000):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models:
+    Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = torch.exp(-math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / (embedding_dim // 2 - downscale_freq_shift))
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+#def get_timestep_embedding(timesteps, embedding_dim):
+# """
+# This matches the implementation in Denoising Diffusion Probabilistic Models:
+# From Fairseq.
+# Build sinusoidal embeddings.
+# This matches the implementation in tensor2tensor, but differs slightly
+# from the description in Section 3.5 of "Attention Is All You Need".
+# """
+# assert len(timesteps.shape) == 1
+#
+# half_dim = embedding_dim // 2
+# emb = math.log(10000) / (half_dim - 1)
+# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+# emb = emb.to(device=timesteps.device)
+# emb = timesteps.float()[:, None] * emb[None, :]
+# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+# if embedding_dim % 2 == 1: # zero pad
+# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+
+#def timestep_embedding(timesteps, dim, max_period=10000):
+# """
+# Create sinusoidal timestep embeddings.
+#
+# :param timesteps: a 1-D Tensor of N indices, one per batch element.
+# These may be fractional.
+# :param dim: the dimension of the output.
+# :param max_period: controls the minimum frequency of the embeddings.
+# :return: an [N x dim] Tensor of positional embeddings.
+# """
+# half = dim // 2
+# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+# device=timesteps.device
+# )
+# args = timesteps[:, None].float() * freqs[None, :]
+# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+# if dim % 2:
+# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+# return embedding
+
+#def a_get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+# assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
+# half_dim = embedding_dim // 2
+# magic number 10000 is from transformers
+# emb = math.log(max_positions) / (half_dim - 1)
+# emb = math.log(2.) / (half_dim - 1)
+# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
+# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
+# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
+# emb = timesteps.float()[:, None] * emb[None, :]
+# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+# if embedding_dim % 2 == 1: # zero pad
+# emb = F.pad(emb, (0, 1), mode="constant")
+# assert emb.shape == (timesteps.shape[0], embedding_dim)
+# return emb

# unet_grad_tts.py
class SinusoidalPosEmb(torch.nn.Module):
@@ -70,26 +125,6 @@ class SinusoidalPosEmb(torch.nn.Module):
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

-# unet_ldm.py
-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding

# unet_rl.py
class SinusoidalPosEmb(nn.Module):
@@ -106,22 +141,6 @@ class SinusoidalPosEmb(nn.Module):
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

-# unet_sde_score_estimation.py
-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-    half_dim = embedding_dim // 2
-    # magic number 10000 is from transformers
-    emb = math.log(max_positions) / (half_dim - 1)
-    # emb = math.log(2.) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-    # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = F.pad(emb, (0, 1), mode="constant")
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb

# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
......
@@ -30,27 +30,28 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding

-def get_timestep_embedding(timesteps, embedding_dim):
-    """
-    This matches the implementation in Denoising Diffusion Probabilistic Models:
-    From Fairseq.
-    Build sinusoidal embeddings.
-    This matches the implementation in tensor2tensor, but differs slightly
-    from the description in Section 3.5 of "Attention Is All You Need".
-    """
-    assert len(timesteps.shape) == 1
-
-    half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
-    emb = emb.to(device=timesteps.device)
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
-    return emb
+#def get_timestep_embedding(timesteps, embedding_dim):
+# """
+# This matches the implementation in Denoising Diffusion Probabilistic Models:
+# From Fairseq.
+# Build sinusoidal embeddings.
+# This matches the implementation in tensor2tensor, but differs slightly
+# from the description in Section 3.5 of "Attention Is All You Need".
+# """
+# assert len(timesteps.shape) == 1
+#
+# half_dim = embedding_dim // 2
+# emb = math.log(10000) / (half_dim - 1)
+# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+# emb = emb.to(device=timesteps.device)
+# emb = timesteps.float()[:, None] * emb[None, :]
+# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+# if embedding_dim % 2 == 1: # zero pad
+# emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+# return emb

def nonlinearity(x):
......
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding

def convert_module_to_f16(l):

@@ -86,25 +87,25 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)

-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
+# def timestep_embedding(timesteps, dim, max_period=10000):
+# """
+# Create sinusoidal timestep embeddings.
+#
+# :param timesteps: a 1-D Tensor of N indices, one per batch element.
+# These may be fractional.
+# :param dim: the dimension of the output.
+# :param max_period: controls the minimum frequency of the embeddings.
+# :return: an [N x dim] Tensor of positional embeddings.
+# """
+# half = dim // 2
+# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+# device=timesteps.device
+# )
+# args = timesteps[:, None].float() * freqs[None]
+# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+# if dim % 2:
+# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+# return embedding

def zero_module(module):

@@ -627,7 +628,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
        """
        hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        h = x.type(self.dtype)
        for module in self.input_blocks:

@@ -714,7 +715,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
    def forward(self, x, timesteps, transformer_out=None):
        hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        # project the last token
        transformer_proj = self.transformer_proj(transformer_out[:, -1])

@@ -806,7 +807,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
        x = torch.cat([x, upsampled], dim=1)
        hs = []
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        h = x
        for module in self.input_blocks:
......
@@ -16,6 +16,7 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding

def exists(val):

@@ -316,34 +317,25 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)

-def timestep_embedding(timesteps, dim, max_period=10000):
-    """
-    Create sinusoidal timestep embeddings.
-
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        device=timesteps.device
-    )
-    args = timesteps[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    return embedding
+#def timestep_embedding(timesteps, dim, max_period=10000):
+# """
+# Create sinusoidal timestep embeddings.
+#
+# :param timesteps: a 1-D Tensor of N indices, one per batch element.
+# These may be fractional.
+# :param dim: the dimension of the output.
+# :param max_period: controls the minimum frequency of the embeddings.
+# :return: an [N x dim] Tensor of positional embeddings.
+# """
+# half = dim // 2
+# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+# device=timesteps.device
+# )
+# args = timesteps[:, None].float() * freqs[None]
+# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+# if dim % 2:
+# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+# return embedding

-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module

## go

@@ -1026,7 +1018,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        hs = []
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
-        t_emb = timestep_embedding(timesteps, self.model_channels)
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        emb = self.time_embed(t_emb)
        if self.num_classes is not None:

@@ -1240,7 +1232,7 @@ class EncoderUNetModel(nn.Module):
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
-        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        emb = self.time_embed(get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0))
        results = []
        h = x.type(self.dtype)
......
@@ -26,6 +26,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
+from .embeddings import get_timestep_embedding

def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):

@@ -381,21 +382,21 @@ def get_act(nonlinearity):
        raise NotImplementedError("activation function does not exist!")

-def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
-    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
-    half_dim = embedding_dim // 2
-    # magic number 10000 is from transformers
-    emb = math.log(max_positions) / (half_dim - 1)
-    # emb = math.log(2.) / (half_dim - 1)
-    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
-    # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
-    # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
-    emb = timesteps.float()[:, None] * emb[None, :]
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-    if embedding_dim % 2 == 1:  # zero pad
-        emb = F.pad(emb, (0, 1), mode="constant")
-    assert emb.shape == (timesteps.shape[0], embedding_dim)
-    return emb
+#def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+# assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
+# half_dim = embedding_dim // 2
+# magic number 10000 is from transformers
+# emb = math.log(max_positions) / (half_dim - 1)
+# emb = math.log(2.) / (half_dim - 1)
+# emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
+# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
+# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
+# emb = timesteps.float()[:, None] * emb[None, :]
+# emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+# if embedding_dim % 2 == 1: # zero pad
+# emb = F.pad(emb, (0, 1), mode="constant")
+# assert emb.shape == (timesteps.shape[0], embedding_dim)
+# return emb

def default_init(scale=1.0):
......
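After the consolidation, the GLIDE and LDM UNets (and the LDM encoder) build their time embedding through the shared helper instead of a local timestep_embedding. A usage sketch, assuming the helper is importable as diffusers.models.embeddings.get_timestep_embedding (the import path is inferred from the relative "from .embeddings import ..." in the diff, and model_channels is an illustrative value):

import torch
from diffusers.models.embeddings import get_timestep_embedding  # path inferred from the relative import above

timesteps = torch.tensor([1, 250, 999])  # a 1-D batch of timesteps
model_channels = 192  # illustrative channel count, not taken from the commit

# GLIDE/LDM-style call sites after this commit (cosine-first ordering)
t_emb = get_timestep_embedding(timesteps, model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)

# unet.py keeps the original DDPM behaviour simply by relying on the defaults
t_emb_ddpm = get_timestep_embedding(timesteps, model_channels)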