Unverified Commit a3e8d3f7 authored by wony617's avatar wony617 Committed by GitHub
Browse files

[docs] refactoring docstrings in `models/embeddings_flax.py` (#9592)



* [docs] refactoring docstrings in `models/embeddings_flax.py`

* Update src/diffusers/models/embeddings_flax.py

* make style

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent fff4be8e
...@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings( ...@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
"""Returns the positional encoding (same as Tensor2Tensor). """Returns the positional encoding (same as Tensor2Tensor).
Args: Args:
timesteps: a 1-D Tensor of N indices, one per batch element. timesteps (`jnp.ndarray` of shape `(N,)`):
These may be fractional. A 1-D array of N indices, one per batch element. These may be fractional.
embedding_dim: The number of output channels. embedding_dim (`int`):
min_timescale: The smallest time unit (should probably be 0.0). The number of output channels.
max_timescale: The largest time unit. freq_shift (`float`, *optional*, defaults to `1`):
Shift applied to the frequency scaling of the embeddings.
min_timescale (`float`, *optional*, defaults to `1`):
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
max_timescale (`float`, *optional*, defaults to `1.0e4`):
The largest time unit used in the sinusoidal calculation.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the order of sinusoidal components to cosine first.
scale (`float`, *optional*, defaults to `1.0`):
A scaling factor applied to the positional embeddings.
Returns: Returns:
a Tensor of timing signals [N, num_channels] a Tensor of timing signals [N, num_channels]
""" """
...@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module): ...@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
Args: Args:
time_embed_dim (`int`, *optional*, defaults to `32`): time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension Time step embedding dimension.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
Parameters `dtype` The data type for the embedding parameters.
""" """
time_embed_dim: int = 32 time_embed_dim: int = 32
...@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module): ...@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
Args: Args:
dim (`int`, *optional*, defaults to `32`): dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension Time step embedding dimension.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sinusoidal function from sine to cosine.
freq_shift (`float`, *optional*, defaults to `1`):
Frequency shift applied to the sinusoidal embeddings.
""" """
dim: int = 32 dim: int = 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment