Unverified Commit 0b61cea3 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Flax] time embedding (#1081)

* initial get_sinusoidal_embeddings

* added asserts

* better var name

* fix docs
parent 33c48745
...@@ -17,23 +17,41 @@ import flax.linen as nn ...@@ -17,23 +17,41 @@ import flax.linen as nn
import jax.numpy as jnp import jax.numpy as jnp
# This is like models.embeddings.get_timestep_embedding (PyTorch) but def get_sinusoidal_embeddings(
# less general (only handles the case we currently need). timesteps: jnp.ndarray,
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): embedding_dim: int,
""" freq_shift: float = 1,
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. min_timescale: float = 1,
max_timescale: float = 1.0e4,
:param timesteps: a 1-D tensor of N indices, one per batch element. flip_sin_to_cos: bool = False,
scale: float = 1.0,
) -> jnp.ndarray:
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional. These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embedding_dim: The number of output channels.
embeddings. :return: an [N x dim] tensor of positional embeddings. min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
Returns:
a Tensor of timing signals [N, num_channels]
""" """
half_dim = embedding_dim // 2 assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
emb = math.log(10000) / (half_dim - freq_shift) assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
emb = jnp.exp(jnp.arange(half_dim) * -emb) num_timescales = float(embedding_dim // 2)
emb = timesteps[:, None] * emb[None, :] log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
return emb emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
# scale embeddings
scaled_time = scale * emb
if flip_sin_to_cos:
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
else:
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
return signal
class FlaxTimestepEmbedding(nn.Module): class FlaxTimestepEmbedding(nn.Module):
...@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module): ...@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module):
@nn.compact @nn.compact
def __call__(self, timesteps): def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) return get_sinusoidal_embeddings(timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment