"vscode:/vscode.git/clone" did not exist on "9c0944581a386736bc808e68d7dfb52d8cf1790e"
Unverified Commit a3e8d3f7 authored by wony617's avatar wony617 Committed by GitHub
Browse files

[docs] refactoring docstrings in `models/embeddings_flax.py` (#9592)



* [docs] refactoring docstrings in `models/embeddings_flax.py`

* Update src/diffusers/models/embeddings_flax.py

* make style

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent fff4be8e
......@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
"""Returns the positional encoding (same as Tensor2Tensor).
Args:
timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
embedding_dim: The number of output channels.
min_timescale: The smallest time unit (should probably be 0.0).
max_timescale: The largest time unit.
timesteps (`jnp.ndarray` of shape `(N,)`):
A 1-D array of N indices, one per batch element. These may be fractional.
embedding_dim (`int`):
The number of output channels.
freq_shift (`float`, *optional*, defaults to `1`):
Shift applied to the frequency scaling of the embeddings.
min_timescale (`float`, *optional*, defaults to `1`):
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
max_timescale (`float`, *optional*, defaults to `1.0e4`):
The largest time unit used in the sinusoidal calculation.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the order of sinusoidal components to cosine first.
scale (`float`, *optional*, defaults to `1.0`):
A scaling factor applied to the positional embeddings.
Returns:
a Tensor of timing signals [N, num_channels]
"""
......@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
Time step embedding dimension.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
The data type for the embedding parameters.
"""
time_embed_dim: int = 32
......@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
Time step embedding dimension.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sinusoidal function from sine to cosine.
freq_shift (`float`, *optional*, defaults to `1`):
Frequency shift applied to the sinusoidal embeddings.
"""
dim: int = 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment