# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import numpy as np
import torch
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (`torch.Tensor`):
            A 1-D tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (`int`):
            The dimension of the output.
        flip_sin_to_cos (`bool`, defaults to `False`):
            If `True`, the cosine half is placed before the sine half in the output.
        downscale_freq_shift (`float`, defaults to `1`):
            Subtracted from `embedding_dim // 2` to form the divisor of the frequency exponents.
        scale (`float`, defaults to `1`):
            Multiplier applied to the `timestep * frequency` products before taking sine and cosine.
        max_period (`int`, defaults to `10000`):
            Controls the minimum frequency of the embeddings.

    Returns:
        `torch.Tensor`: an [N x embedding_dim] tensor of positional embeddings. When `embedding_dim` is odd, the
        last column is zero padding.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )

    denominator = half_dim - downscale_freq_shift
    if half_dim == 1 and denominator == 0:
        # The only frequency index is 0, so the numerator is exactly 0 and the exponent is 0 for every non-zero
        # divisor. Dividing by the zero divisor here would give 0 / 0 = NaN (this is what happens for
        # `embedding_dim` 2 or 3 with the default `downscale_freq_shift=1`), so use 0 directly.
        exponent = torch.zeros_like(exponent)
    else:
        exponent = exponent / denominator

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TimestepEmbedding(nn.Module):
    """
    Two linear layers with an optional activation in between.

    Args:
        in_channels (`int`):
            Size of the last dimension of the input.
        time_embed_dim (`int`):
            Output size of the first linear layer.
        act_fn (`str`, defaults to `"silu"`):
            Activation applied between the two linear layers. `"silu"` and `"mish"` are recognised; any other value
            leaves the activation unset, so the two linear layers are applied back to back.
        out_dim (`int`, *optional*):
            Output size of the second linear layer. Falls back to `time_embed_dim` when `None`.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: Optional[int] = None):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        elif act_fn == "mish":
            self.act = nn.Mish()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

    def forward(self, sample):
        """Apply `linear_1`, then the activation (if one is set), then `linear_2`, and return the result."""
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):
    """
    Module wrapper around `get_timestep_embedding` that stores the embedding settings. It creates no parameters.

    Args:
        num_channels (`int`):
            Passed to `get_timestep_embedding` as `embedding_dim`.
        flip_sin_to_cos (`bool`):
            Passed through to `get_timestep_embedding` under the same name.
        downscale_freq_shift (`float`):
            Passed through to `get_timestep_embedding` under the same name.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        """Return `get_timestep_embedding(timesteps, ...)` evaluated with the stored settings."""
        # `scale` and `max_period` are not forwarded, so the function's own defaults apply.
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier embeddings for noise levels.

    Args:
        embedding_size (`int`, defaults to `256`):
            Number of random frequencies. The output has `2 * embedding_size` features (sine and cosine halves).
        scale (`float`, defaults to `1.0`):
            Multiplier applied to the normally distributed frequencies when they are sampled.
        set_W_to_weight (`bool`, defaults to `True`):
            If `True`, a second set of frequencies is sampled into `W` and `weight` is re-pointed at it.
        log (`bool`, defaults to `True`):
            If `True`, `torch.log` is applied to the input before projecting.
        flip_sin_to_cos (`bool`, defaults to `False`):
            If `True`, the output is `[cos, sin]` instead of `[sin, cos]`.
    """

    def __init__(
        self,
        embedding_size: int = 256,
        scale: float = 1.0,
        set_W_to_weight: bool = True,
        log: bool = True,
        flip_sin_to_cos: bool = False,
    ):
        super().__init__()
        # Frozen random frequencies (requires_grad=False).
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            # A second draw is made on purpose so the random-number stream is consumed exactly as before;
            # afterwards `weight` and `W` refer to the same parameter object.
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        """
        Project `x` onto the stored frequencies and return the concatenated sine/cosine features.

        `x` is indexed as `x[:, None]`, so a 1-D input of length N yields an [N x 2 * embedding_size] output.
        """
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class ImagePositionalEmbeddings(nn.Module):
    """
    Embeds latent image classes and adds learned row/column position embeddings of the latent grid to them.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    The resulting vector embeddings are the input of the transformer.

    Note that these transformer embeddings are not the same as the vector embeddings of the VQVAE.

    Args:
        num_embed (`int`):
            Size of the embedding table for the latent pixels.
        height (`int`):
            Height of the latent image, i.e. how many height embeddings are created.
        width (`int`):
            Width of the latent image, i.e. how many width embeddings are created.
        embed_dim (`int`):
            Dimension of every produced vector; shared by the latent pixel, height and width embeddings.
    """

    def __init__(self, num_embed: int, height: int, width: int, embed_dim: int):
        super().__init__()

        self.height, self.width = height, width
        self.num_embed, self.embed_dim = num_embed, embed_dim

        # One lookup table for the latent classes and one per spatial axis.
        self.emb = nn.Embedding(num_embed, embed_dim)
        self.height_emb = nn.Embedding(height, embed_dim)
        self.width_emb = nn.Embedding(width, embed_dim)

    def forward(self, index):
        device = index.device
        token_vectors = self.emb(index)

        # Position ids for every row / column of the grid, each shaped 1 x H and 1 x W.
        row_ids = torch.arange(self.height, device=device).unsqueeze(0)
        col_ids = torch.arange(self.width, device=device).unsqueeze(0)

        # 1 x H x D -> 1 x H x 1 x D
        row_vectors = self.height_emb(row_ids)[:, :, None, :]
        # 1 x W x D -> 1 x 1 x W x D
        col_vectors = self.width_emb(col_ids)[:, None, :, :]

        # Broadcast sum gives 1 x H x W x D, which is then flattened to 1 x L x D.
        grid = (row_vectors + col_vectors).view(1, self.height * self.width, -1)

        # Only as many positions as there are tokens in the sequence are added.
        seq_len = token_vectors.shape[1]
        return token_vectors + grid[:, :seq_len, :]