Unverified Commit a6a25ceb authored by Akash Gokul's avatar Akash Gokul Committed by GitHub
Browse files

Fix Flax flip_sin_to_cos (#1369)



* Fix Flax flip_sin_to_cos

* Adding flip_sin_to_cos
Co-authored-by: default avatarKashif Rasul <kashif.rasul@gmail.com>
parent b85bb075
...@@ -84,10 +84,11 @@ class FlaxTimesteps(nn.Module): ...@@ -84,10 +84,11 @@ class FlaxTimesteps(nn.Module):
Time step embedding dimension Time step embedding dimension
""" """
dim: int = 32 dim: int = 32
flip_sin_to_cos: bool = False
freq_shift: float = 1 freq_shift: float = 1
@nn.compact @nn.compact
def __call__(self, timesteps): def __call__(self, timesteps):
return get_sinusoidal_embeddings( return get_sinusoidal_embeddings(
timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
) )
...@@ -85,6 +85,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -85,6 +85,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The dimension of the cross attention features. The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0): dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks. Dropout probability for down, up and bottleneck blocks.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
""" """
sample_size: int = 32 sample_size: int = 32
...@@ -105,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -105,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
dropout: float = 0.0 dropout: float = 0.0
use_linear_projection: bool = False use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
flip_sin_to_cos: bool = True
freq_shift: int = 0 freq_shift: int = 0
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
...@@ -133,7 +138,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -133,7 +138,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
) )
# time # time
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift) self.time_proj = FlaxTimesteps(
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
only_cross_attention = self.only_cross_attention only_cross_attention = self.only_cross_attention
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment