# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp

from ..utils import logging


logger = logging.get_logger(__name__)

def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Args:
        timesteps (`jnp.ndarray` of shape `(N,)`):
            A 1-D array of N indices, one per batch element. These may be fractional.
        embedding_dim (`int`):
            The number of output channels.
        freq_shift (`float`, *optional*, defaults to `1`):
            Shift applied to the frequency scaling of the embeddings.
        min_timescale (`float`, *optional*, defaults to `1`):
            The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
        max_timescale (`float`, *optional*, defaults to `1.0e4`):
            The largest time unit used in the sinusoidal calculation.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the order of sinusoidal components to cosine first.
        scale (`float`, *optional*, defaults to `1.0`):
            A scaling factor applied to the positional embeddings.

    Returns:
        a Tensor of timing signals [N, num_channels]

    Raises:
        ValueError: If `timesteps` is not a 1-D array or `embedding_dim` is odd.
    """
    # Validate with explicit exceptions rather than `assert`, which is stripped under `python -O`.
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")
    if embedding_dim % 2 != 0:
        raise ValueError(f"Embedding dimension {embedding_dim} should be even")
    # Half the channels carry sine components, the other half cosine components.
    num_timescales = float(embedding_dim // 2)
    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
    # Outer product: one row per timestep, one column per frequency -> shape (N, num_timescales).
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)

    # scale embeddings
    scaled_time = scale * emb

    if flip_sin_to_cos:
        signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
    else:
        signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
    signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
    return signal


class FlaxTimestepEmbedding(nn.Module):
    r"""
    Learned time-step embedding: projects an already-encoded time-step signal through a
    two-layer MLP (Dense -> SiLU -> Dense).

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The data type for the embedding parameters.
    """

    logger.warning(
        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
    )

    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        # Layer names "linear_1"/"linear_2" are fixed to stay checkpoint-compatible.
        hidden = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        hidden = nn.silu(hidden)
        return nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(hidden)


class FlaxTimesteps(nn.Module):
    r"""
    Flax module producing sinusoidal time-step embeddings, as described in
    https://huggingface.co/papers/2006.11239.

    Args:
        dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sinusoidal function from sine to cosine.
        freq_shift (`float`, *optional*, defaults to `1`):
            Frequency shift applied to the sinusoidal embeddings.
    """

    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    logger.warning(
        "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
        "recommend migrating to PyTorch classes or pinning your version of Diffusers."
    )

    @nn.compact
    def __call__(self, timesteps):
        # Delegate to the functional helper, forwarding this module's configuration.
        return get_sinusoidal_embeddings(
            timesteps,
            embedding_dim=self.dim,
            flip_sin_to_cos=self.flip_sin_to_cos,
            freq_shift=self.freq_shift,
        )