Commit 6846ee2a authored by Patrick von Platen

finalize position embeddings

parents c7a39d38 9b9afc97
@@ -23,7 +23,6 @@ def get_timestep_embedding(
 ):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     Create sinusoidal timestep embeddings.
     :param timesteps: a 1-D Tensor of N indices, one per batch element.
......
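For reference, a minimal sketch of what a sinusoidal timestep embedding of this kind computes. This is not a copy of the library's `get_timestep_embedding`; the helper name, the exact frequency schedule, and the sin/cos ordering are assumptions for illustration only.

import math

import torch
import torch.nn.functional as F


def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int, scale: float = 1.0) -> torch.Tensor:
    """Sketch of a DDPM-style sinusoidal embedding for a 1-D tensor of timestep indices."""
    half_dim = dim // 2
    # Geometric progression of frequencies, from 1 down to roughly 1/10000.
    exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1)
    freqs = torch.exp(exponent)
    # Outer product of (scaled) timesteps and frequencies, then sin/cos features.
    args = scale * timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))  # zero-pad to the requested dimension when dim is odd
    return emb  # shape: (len(timesteps), dim)


# Example: emb = sinusoidal_timestep_embedding(torch.arange(8), dim=64)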
@@ -182,7 +182,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             s = self.spk_mlp(spk)
         t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
         t = self.mlp(t)
         if self.n_spks < 2:
......
@@ -16,7 +16,6 @@
 # helpers functions
 import functools
-import math
 import string
 import numpy as np
@@ -26,7 +25,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .embeddings import get_timestep_embedding
+from .embeddings import GaussianFourierProjection, get_timestep_embedding
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -418,18 +417,6 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
     return init
-class GaussianFourierProjection(nn.Module):
-    """Gaussian Fourier embeddings for noise levels."""
-
-    def __init__(self, embedding_size=256, scale=1.0):
-        super().__init__()
-        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
-    def forward(self, x):
-        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
-        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
-
 class Combine(nn.Module):
     """Combine information from skip connections."""
......
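Given the new import `from .embeddings import GaussianFourierProjection`, the class removed above appears to have been relocated into the shared embeddings module. A sketch of what the relocated class would look like there; the body mirrors the deleted lines, while the module path (models/embeddings.py) is an assumption.

import numpy as np
import torch
import torch.nn as nn


class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        # Fixed (non-trainable) random frequencies drawn once at construction.
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        # Project scalar noise levels onto the random frequencies, then take sin/cos.
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)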
@@ -679,6 +679,7 @@ class PipelineTesterMixin(unittest.TestCase):
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
+    @unittest.skip("Skipping for now as it takes too long")
     def test_ldm_text2img(self):
         model_id = "fusing/latent-diffusion-text2im-large"
         ldm = LatentDiffusionPipeline.from_pretrained(model_id)
@@ -693,6 +694,21 @@
         expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

+    @slow
+    def test_ldm_text2img_fast(self):
+        model_id = "fusing/latent-diffusion-text2im-large"
+        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.manual_seed(0)
+        image = ldm([prompt], generator=generator, num_inference_steps=1)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 256, 256)
+        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
     @slow
     def test_glide_text2img(self):
         model_id = "fusing/glide-base"
......