embeddings_flax.py 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
22
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
    """
    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps:
            A 1-D array of N timesteps, one per batch element. These may be fractional.
        embedding_dim (`int`):
            The dimension of the output embeddings.
        freq_shift (`float`, *optional*, defaults to `1`):
            Value subtracted from `embedding_dim // 2` in the denominator of the log-frequency step
            `log(10000) / (embedding_dim // 2 - freq_shift)`.

    Returns:
        An `[N x embedding_dim]` array holding the cosine half followed by the sine half. When `embedding_dim` is
        odd, a single trailing zero column is appended so the last axis is exactly `embedding_dim` wide.

    Raises:
        ValueError: If `embedding_dim // 2 == freq_shift`, which would make the frequency step a division by zero.
    """
    half_dim = embedding_dim // 2
    if half_dim == freq_shift:
        raise ValueError(
            f"`embedding_dim // 2` ({half_dim}) must differ from `freq_shift` ({freq_shift}); "
            "otherwise the frequency step divides by zero."
        )

    # Geometric progression of frequencies: exp(-i * log(10000) / (half_dim - freq_shift)) for i in [0, half_dim).
    log_step = math.log(10000) / (half_dim - freq_shift)
    freqs = jnp.exp(jnp.arange(half_dim) * -log_step)

    # Outer product: one row per timestep, one column per frequency.
    angles = timesteps[:, None] * freqs[None, :]
    emb = jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], -1)

    # cos/sin halves only cover 2 * half_dim columns; zero-pad so odd dimensions honour the documented shape.
    if embedding_dim % 2 == 1:
        emb = jnp.pad(emb, ((0, 0), (0, 1)))
    return emb


class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step embedding module: a two-layer dense projection with a SiLU activation in between.

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
                Number of output features of both dense layers
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
                `dtype` passed to both dense layers
    """
    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        # Layer names are part of the parameter tree and must stay "linear_1" / "linear_2".
        first_projection = nn.Dense(features=self.time_embed_dim, dtype=self.dtype, name="linear_1")
        hidden = nn.silu(first_projection(temb))

        second_projection = nn.Dense(features=self.time_embed_dim, dtype=self.dtype, name="linear_2")
        return second_projection(hidden)


class FlaxTimesteps(nn.Module):
    r"""
    Thin module wrapper around `get_sinusoidal_embeddings`, producing the sinusoidal time step embeddings described
    in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
        freq_shift (`float`, *optional*, defaults to `1`):
                Forwarded unchanged as the `freq_shift` argument of `get_sinusoidal_embeddings`
    """
    dim: int = 32
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        # No parameters are created here; the module only forwards its configuration.
        return get_sinusoidal_embeddings(
            timesteps=timesteps,
            embedding_dim=self.dim,
            freq_shift=self.freq_shift,
        )