Unverified Commit 0b61cea3 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Flax] time embedding (#1081)

* initial get_sinusoidal_embeddings

* added asserts

* better var name

* fix docs
parent 33c48745
...@@ -17,23 +17,41 @@ import flax.linen as nn ...@@ -17,23 +17,41 @@ import flax.linen as nn
import jax.numpy as jnp import jax.numpy as jnp
# This is like models.embeddings.get_timestep_embedding (PyTorch) but def get_sinusoidal_embeddings(
# less general (only handles the case we currently need). timesteps: jnp.ndarray,
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): embedding_dim: int,
freq_shift: float = 1,
min_timescale: float = 1,
max_timescale: float = 1.0e4,
flip_sin_to_cos: bool = False,
scale: float = 1.0,
) -> jnp.ndarray:
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
num_timescales = float(embedding_dim // 2)
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
:param timesteps: a 1-D tensor of N indices, one per batch element. # scale embeddings
These may be fractional. scaled_time = scale * emb
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] tensor of positional embeddings. if flip_sin_to_cos:
""" signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
half_dim = embedding_dim // 2 else:
emb = math.log(10000) / (half_dim - freq_shift) signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
emb = jnp.exp(jnp.arange(half_dim) * -emb) signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
emb = timesteps[:, None] * emb[None, :] return signal
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb
class FlaxTimestepEmbedding(nn.Module): class FlaxTimestepEmbedding(nn.Module):
...@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module): ...@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module):
@nn.compact @nn.compact
def __call__(self, timesteps): def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) return get_sinusoidal_embeddings(timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment